mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-09 00:33:58 +01:00
Merge remote-tracking branch 'origin/master' into holepunch-remote-allow-list
This commit is contained in:
commit
97f0abdc55
2
.github/workflows/gofmt.yml
vendored
2
.github/workflows/gofmt.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
|
||||
- name: Install goimports
|
||||
|
||||
6
.github/workflows/release.yml
vendored
6
.github/workflows/release.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@ -37,7 +37,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@ -70,7 +70,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
|
||||
- name: Import certificates
|
||||
|
||||
3
.github/workflows/smoke-extra.yml
vendored
3
.github/workflows/smoke-extra.yml
vendored
@ -27,6 +27,9 @@ jobs:
|
||||
go-version-file: 'go.mod'
|
||||
check-latest: true
|
||||
|
||||
- name: add hashicorp source
|
||||
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
|
||||
|
||||
- name: install vagrant
|
||||
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
|
||||
|
||||
|
||||
2
.github/workflows/smoke.yml
vendored
2
.github/workflows/smoke.yml
vendored
@ -22,7 +22,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
|
||||
- name: build
|
||||
|
||||
10
.github/workflows/smoke/build.sh
vendored
10
.github/workflows/smoke/build.sh
vendored
@ -5,6 +5,10 @@ set -e -x
|
||||
rm -rf ./build
|
||||
mkdir ./build
|
||||
|
||||
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
|
||||
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
|
||||
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
|
||||
|
||||
(
|
||||
cd build
|
||||
|
||||
@ -21,16 +25,16 @@ mkdir ./build
|
||||
../genconfig.sh >lighthouse1.yml
|
||||
|
||||
HOST="host2" \
|
||||
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
||||
../genconfig.sh >host2.yml
|
||||
|
||||
HOST="host3" \
|
||||
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
||||
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||
../genconfig.sh >host3.yml
|
||||
|
||||
HOST="host4" \
|
||||
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
||||
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||
../genconfig.sh >host4.yml
|
||||
|
||||
|
||||
34
.github/workflows/smoke/smoke-vagrant.sh
vendored
34
.github/workflows/smoke/smoke-vagrant.sh
vendored
@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
||||
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
||||
|
||||
vagrant up
|
||||
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test"
|
||||
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
|
||||
|
||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||
sleep 1
|
||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||
sleep 1
|
||||
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" &
|
||||
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
||||
sleep 15
|
||||
|
||||
# grab tcpdump pcaps for debugging
|
||||
@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host
|
||||
# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
||||
# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
||||
|
||||
docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
||||
vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
|
||||
#docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
||||
#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
|
||||
#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
||||
#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
|
||||
|
||||
@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1
|
||||
# Should fail because not allowed by host3 inbound firewall
|
||||
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
|
||||
|
||||
set +x
|
||||
echo
|
||||
echo " *** Testing ncat from host2"
|
||||
echo
|
||||
set -x
|
||||
#set +x
|
||||
#echo
|
||||
#echo " *** Testing ncat from host2"
|
||||
#echo
|
||||
#set -x
|
||||
# Should fail because not allowed by host3 inbound firewall
|
||||
#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
|
||||
#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
|
||||
@ -82,18 +82,18 @@ echo
|
||||
echo " *** Testing ping from host3"
|
||||
echo
|
||||
set -x
|
||||
vagrant ssh -c "ping -c1 192.168.100.1"
|
||||
vagrant ssh -c "ping -c1 192.168.100.2"
|
||||
vagrant ssh -c "ping -c1 192.168.100.1" -- -T
|
||||
vagrant ssh -c "ping -c1 192.168.100.2" -- -T
|
||||
|
||||
set +x
|
||||
echo
|
||||
echo " *** Testing ncat from host3"
|
||||
echo
|
||||
set -x
|
||||
#set +x
|
||||
#echo
|
||||
#echo " *** Testing ncat from host3"
|
||||
#echo
|
||||
#set -x
|
||||
#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
|
||||
#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
|
||||
|
||||
vagrant ssh -c "sudo xargs kill </nebula/pid"
|
||||
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
|
||||
docker exec host2 sh -c 'kill 1'
|
||||
docker exec lighthouse1 sh -c 'kill 1'
|
||||
sleep 1
|
||||
|
||||
18
.github/workflows/test.yml
vendored
18
.github/workflows/test.yml
vendored
@ -22,7 +22,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@ -31,6 +31,11 @@ jobs:
|
||||
- name: Vet
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
version: v1.64
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
|
||||
@ -55,7 +60,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@ -65,7 +70,7 @@ jobs:
|
||||
run: make test-boringcrypto
|
||||
|
||||
- name: End 2 end
|
||||
run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
||||
run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
|
||||
|
||||
test-linux-pkcs11:
|
||||
name: Build and test on linux with pkcs11
|
||||
@ -97,7 +102,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version: '1.23'
|
||||
check-latest: true
|
||||
|
||||
- name: Build nebula
|
||||
@ -109,6 +114,11 @@ jobs:
|
||||
- name: Vet
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
version: v1.64
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -5,7 +5,8 @@
|
||||
/nebula-darwin
|
||||
/nebula.exe
|
||||
/nebula-cert.exe
|
||||
/coverage.out
|
||||
**/coverage.out
|
||||
**/cover.out
|
||||
/cpu.pprof
|
||||
/build
|
||||
/*.tar.gz
|
||||
|
||||
9
.golangci.yaml
Normal file
9
.golangci.yaml
Normal file
@ -0,0 +1,9 @@
|
||||
# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json
|
||||
linters:
|
||||
# Disable all linters.
|
||||
# Default: false
|
||||
disable-all: true
|
||||
# Enable specific linter
|
||||
# https://golangci-lint.run/usage/linters/#enabled-by-default
|
||||
enable:
|
||||
- testifylint
|
||||
6
Makefile
6
Makefile
@ -137,6 +137,8 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
|
||||
# boringcrypto
|
||||
build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
||||
build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
||||
build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
|
||||
build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0
|
||||
|
||||
build/%/nebula: .FORCE
|
||||
GOOS=$(firstword $(subst -, , $*)) \
|
||||
@ -170,7 +172,7 @@ test:
|
||||
go test -v ./...
|
||||
|
||||
test-boringcrypto:
|
||||
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
|
||||
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...
|
||||
|
||||
test-pkcs11:
|
||||
CGO_ENABLED=1 go test -v -tags pkcs11 ./...
|
||||
@ -196,7 +198,7 @@ bench-cpu-long:
|
||||
go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
|
||||
go tool pprof go-audit.test cpu.pprof
|
||||
|
||||
proto: nebula.pb.go cert/cert.pb.go
|
||||
proto: nebula.pb.go cert/cert_v1.pb.go
|
||||
|
||||
nebula.pb.go: nebula.proto .FORCE
|
||||
go build github.com/gogo/protobuf/protoc-gen-gogofaster
|
||||
|
||||
@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
|
||||
|
||||
You can read more about Nebula [here](https://medium.com/p/884110a5579).
|
||||
|
||||
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU).
|
||||
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
|
||||
|
||||
## Supported Platforms
|
||||
|
||||
@ -47,7 +47,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
|
||||
$ sudo apk add nebula
|
||||
```
|
||||
|
||||
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb)
|
||||
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
|
||||
```
|
||||
$ brew install nebula
|
||||
```
|
||||
|
||||
@ -128,7 +128,6 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
||||
|
||||
ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
|
||||
|
||||
// TODO: should we error on duplicate CIDRs in the config?
|
||||
tree.Insert(ipNet, value)
|
||||
|
||||
maskBits := ipNet.Bits()
|
||||
@ -251,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
|
||||
return remoteAllowRanges, nil
|
||||
}
|
||||
|
||||
func (al *AllowList) Allow(ip netip.Addr) bool {
|
||||
func (al *AllowList) Allow(addr netip.Addr) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
result, _ := al.cidrTree.Lookup(ip)
|
||||
result, _ := al.cidrTree.Lookup(addr)
|
||||
return result
|
||||
}
|
||||
|
||||
func (al *LocalAllowList) Allow(ip netip.Addr) bool {
|
||||
func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
return al.AllowList.Allow(ip)
|
||||
return al.AllowList.Allow(udpAddr)
|
||||
}
|
||||
|
||||
func (al *LocalAllowList) AllowName(name string) bool {
|
||||
@ -282,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool {
|
||||
return !al.nameRules[0].Allow
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
|
||||
func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool {
|
||||
if al == nil {
|
||||
return true
|
||||
}
|
||||
return al.AllowList.Allow(ip)
|
||||
return al.AllowList.Allow(vpnAddr)
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
|
||||
if !al.getInsideAllowList(vpnIp).Allow(ip) {
|
||||
func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool {
|
||||
if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
|
||||
return false
|
||||
}
|
||||
return al.AllowList.Allow(ip)
|
||||
return al.AllowList.Allow(udpAddr)
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
|
||||
func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool {
|
||||
if !al.AllowList.Allow(udpAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, vpnAddr := range vpnAddrs {
|
||||
if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList {
|
||||
if al.insideAllowLists != nil {
|
||||
inside, ok := al.insideAllowLists.Lookup(vpnIp)
|
||||
inside, ok := al.insideAllowLists.Lookup(vpnAddr)
|
||||
if ok {
|
||||
return inside
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewAllowListFromConfig(t *testing.T) {
|
||||
@ -18,21 +19,21 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
"192.168.0.0": true,
|
||||
}
|
||||
r, err := newAllowListFromConfig(c, "allowlist", nil)
|
||||
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
|
||||
require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
|
||||
assert.Nil(t, r)
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"192.168.0.0/16": "abc",
|
||||
}
|
||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
||||
require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"192.168.0.0/16": true,
|
||||
"10.0.0.0/8": false,
|
||||
}
|
||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
||||
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"0.0.0.0/0": true,
|
||||
@ -42,7 +43,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
"fd00:fd00::/16": false,
|
||||
}
|
||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
||||
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"0.0.0.0/0": true,
|
||||
@ -75,7 +76,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
},
|
||||
}
|
||||
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
|
||||
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
||||
require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
@ -84,7 +85,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
},
|
||||
}
|
||||
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
|
||||
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
||||
require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
@ -98,7 +99,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAllowList_Allow(t *testing.T) {
|
||||
assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
|
||||
assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
|
||||
|
||||
tree := new(bart.Table[bool])
|
||||
tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
|
||||
@ -111,17 +112,17 @@ func TestAllowList_Allow(t *testing.T) {
|
||||
tree.Insert(netip.MustParsePrefix("::2/128"), false)
|
||||
al := &AllowList{cidrTree: tree}
|
||||
|
||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
|
||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
|
||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
|
||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
|
||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
|
||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
|
||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
|
||||
assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1")))
|
||||
assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4")))
|
||||
assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42")))
|
||||
assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41")))
|
||||
assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1")))
|
||||
assert.True(t, al.Allow(netip.MustParseAddr("::1")))
|
||||
assert.False(t, al.Allow(netip.MustParseAddr("::2")))
|
||||
}
|
||||
|
||||
func TestLocalAllowList_AllowName(t *testing.T) {
|
||||
assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0"))
|
||||
assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0"))
|
||||
|
||||
rules := []AllowListNameRule{
|
||||
{Name: regexp.MustCompile("^docker.*$"), Allow: false},
|
||||
@ -129,9 +130,9 @@ func TestLocalAllowList_AllowName(t *testing.T) {
|
||||
}
|
||||
al := &LocalAllowList{nameRules: rules}
|
||||
|
||||
assert.Equal(t, false, al.AllowName("docker0"))
|
||||
assert.Equal(t, false, al.AllowName("tun0"))
|
||||
assert.Equal(t, true, al.AllowName("eth0"))
|
||||
assert.False(t, al.AllowName("docker0"))
|
||||
assert.False(t, al.AllowName("tun0"))
|
||||
assert.True(t, al.AllowName("eth0"))
|
||||
|
||||
rules = []AllowListNameRule{
|
||||
{Name: regexp.MustCompile("^eth.*$"), Allow: true},
|
||||
@ -139,7 +140,7 @@ func TestLocalAllowList_AllowName(t *testing.T) {
|
||||
}
|
||||
al = &LocalAllowList{nameRules: rules}
|
||||
|
||||
assert.Equal(t, false, al.AllowName("docker0"))
|
||||
assert.Equal(t, true, al.AllowName("eth0"))
|
||||
assert.Equal(t, true, al.AllowName("ens5"))
|
||||
assert.False(t, al.AllowName("docker0"))
|
||||
assert.True(t, al.AllowName("eth0"))
|
||||
assert.True(t, al.AllowName("ens5"))
|
||||
}
|
||||
|
||||
@ -21,7 +21,11 @@ type calculatedRemote struct {
|
||||
port uint32
|
||||
}
|
||||
|
||||
func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
|
||||
func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
|
||||
if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
|
||||
return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
|
||||
}
|
||||
|
||||
masked := maskCidr.Masked()
|
||||
if port < 0 || port > math.MaxUint16 {
|
||||
return nil, fmt.Errorf("invalid port: %d", port)
|
||||
@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
|
||||
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
|
||||
}
|
||||
|
||||
func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
|
||||
// Combine the masked bytes of the "mask" IP with the unmasked bytes
|
||||
// of the overlay IP
|
||||
if c.ipNet.Addr().Is4() {
|
||||
return c.apply4(ip)
|
||||
}
|
||||
return c.apply6(ip)
|
||||
}
|
||||
|
||||
func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
|
||||
//TODO: IPV6-WORK this can be less crappy
|
||||
func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
|
||||
// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
|
||||
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
|
||||
mask := binary.BigEndian.Uint32(maskb[:])
|
||||
|
||||
b := c.mask.Addr().As4()
|
||||
maskIp := binary.BigEndian.Uint32(b[:])
|
||||
maskAddr := binary.BigEndian.Uint32(b[:])
|
||||
|
||||
b = ip.As4()
|
||||
intIp := binary.BigEndian.Uint32(b[:])
|
||||
b = addr.As4()
|
||||
intAddr := binary.BigEndian.Uint32(b[:])
|
||||
|
||||
return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
|
||||
return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
|
||||
}
|
||||
|
||||
func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
|
||||
//TODO: IPV6-WORK
|
||||
panic("Can not calculate ipv6 remote addresses")
|
||||
func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
|
||||
mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
|
||||
maskAddr := c.mask.Addr().As16()
|
||||
calcAddr := addr.As16()
|
||||
|
||||
ap := V6AddrPort{Port: c.port}
|
||||
|
||||
maskb := binary.BigEndian.Uint64(mask[:8])
|
||||
maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
|
||||
calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
|
||||
ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
|
||||
|
||||
maskb = binary.BigEndian.Uint64(mask[8:])
|
||||
maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
|
||||
calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
|
||||
ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
|
||||
|
||||
return &ap
|
||||
}
|
||||
|
||||
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
|
||||
@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
|
||||
entry, err := newCalculatedRemotesListFromConfig(rawValue)
|
||||
entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
|
||||
}
|
||||
@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
||||
return calculatedRemotes, nil
|
||||
}
|
||||
|
||||
func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
|
||||
func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
|
||||
rawList, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
|
||||
@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
|
||||
|
||||
var l []*calculatedRemote
|
||||
for _, e := range rawList {
|
||||
c, err := newCalculatedRemotesEntryFromConfig(e)
|
||||
c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("calculated_remotes entry: %w", err)
|
||||
}
|
||||
@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
|
||||
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
||||
rawMap, ok := raw.(map[any]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid type: %T", raw)
|
||||
@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
|
||||
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
|
||||
}
|
||||
|
||||
return newCalculatedRemote(maskCidr, port)
|
||||
return newCalculatedRemote(cidr, maskCidr, port)
|
||||
}
|
||||
|
||||
@ -9,17 +9,73 @@ import (
|
||||
)
|
||||
|
||||
func TestCalculatedRemoteApply(t *testing.T) {
|
||||
ipNet, err := netip.ParsePrefix("192.168.1.0/24")
|
||||
require.NoError(t, err)
|
||||
|
||||
c, err := newCalculatedRemote(ipNet, 4242)
|
||||
// Test v4 addresses
|
||||
ipNet := netip.MustParsePrefix("192.168.1.0/24")
|
||||
c, err := newCalculatedRemote(ipNet, ipNet, 4242)
|
||||
require.NoError(t, err)
|
||||
|
||||
input, err := netip.ParseAddr("10.0.10.182")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected, err := netip.ParseAddr("192.168.1.182")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
|
||||
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
|
||||
|
||||
// Test v6 addresses
|
||||
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
|
||||
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
|
||||
require.NoError(t, err)
|
||||
|
||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||
require.NoError(t, err)
|
||||
|
||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||
|
||||
// Test v6 addresses part 2
|
||||
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
|
||||
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
|
||||
require.NoError(t, err)
|
||||
|
||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||
require.NoError(t, err)
|
||||
|
||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||
|
||||
// Test v6 addresses part 2
|
||||
ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
|
||||
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
|
||||
require.NoError(t, err)
|
||||
|
||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||
require.NoError(t, err)
|
||||
|
||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||
}
|
||||
|
||||
func Test_newCalculatedRemote(t *testing.T) {
|
||||
c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
|
||||
require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
|
||||
require.Nil(t, c)
|
||||
|
||||
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
|
||||
require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
|
||||
require.Nil(t, c)
|
||||
|
||||
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, c)
|
||||
|
||||
c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, c)
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
GO111MODULE = on
|
||||
export GO111MODULE
|
||||
|
||||
cert.pb.go: cert.proto .FORCE
|
||||
cert_v1.pb.go: cert_v1.proto .FORCE
|
||||
go build google.golang.org/protobuf/cmd/protoc-gen-go
|
||||
PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $<
|
||||
rm protoc-gen-go
|
||||
|
||||
@ -2,14 +2,25 @@
|
||||
|
||||
This is a library for interacting with `nebula` style certificates and authorities.
|
||||
|
||||
A `protobuf` definition of the certificate format is also included
|
||||
There are now 2 versions of `nebula` certificates:
|
||||
|
||||
### Compiling the protobuf definition
|
||||
## v1
|
||||
|
||||
Make sure you have `protoc` installed.
|
||||
This version is deprecated.
|
||||
|
||||
A `protobuf` definition of the certificate format is included at `cert_v1.proto`
|
||||
|
||||
To compile the definition you will need `protoc` installed.
|
||||
|
||||
To compile for `go` with the same version of protobuf specified in go.mod:
|
||||
|
||||
```bash
|
||||
make
|
||||
make proto
|
||||
```
|
||||
|
||||
## v2
|
||||
|
||||
This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
|
||||
future certificate changes better than v1.
|
||||
|
||||
`cert_v2.asn1` defines the wire format and can be used to compile marshalers.
|
||||
52
cert/asn1.go
Normal file
52
cert/asn1.go
Normal file
@ -0,0 +1,52 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"golang.org/x/crypto/cryptobyte/asn1"
|
||||
)
|
||||
|
||||
// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
|
||||
// https://github.com/golang/go/issues/64811#issuecomment-1944446920
|
||||
func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
|
||||
var present bool
|
||||
var child cryptobyte.String
|
||||
if !b.ReadOptionalASN1(&child, &present, tag) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !present {
|
||||
*out = defaultValue
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure we have 1 byte
|
||||
if len(child) == 1 {
|
||||
*out = child[0] > 0
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
|
||||
// Similar issue as with readOptionalASN1Boolean
|
||||
func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
|
||||
var present bool
|
||||
var child cryptobyte.String
|
||||
if !b.ReadOptionalASN1(&child, &present, tag) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !present {
|
||||
*out = defaultValue
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure we have 1 byte
|
||||
if len(child) == 1 {
|
||||
*out = child[0]
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
140
cert/ca.go
140
cert/ca.go
@ -1,140 +0,0 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NebulaCAPool struct {
|
||||
CAs map[string]*NebulaCertificate
|
||||
certBlocklist map[string]struct{}
|
||||
}
|
||||
|
||||
// NewCAPool creates a CAPool
|
||||
func NewCAPool() *NebulaCAPool {
|
||||
ca := NebulaCAPool{
|
||||
CAs: make(map[string]*NebulaCertificate),
|
||||
certBlocklist: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
return &ca
|
||||
}
|
||||
|
||||
// NewCAPoolFromBytes will create a new CA pool from the provided
|
||||
// input bytes, which must be a PEM-encoded set of nebula certificates.
|
||||
// If the pool contains any expired certificates, an ErrExpired will be
|
||||
// returned along with the pool. The caller must handle any such errors.
|
||||
func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
|
||||
pool := NewCAPool()
|
||||
var err error
|
||||
var expired bool
|
||||
for {
|
||||
caPEMs, err = pool.AddCACertificate(caPEMs)
|
||||
if errors.Is(err, ErrExpired) {
|
||||
expired = true
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
return pool, ErrExpired
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// AddCACertificate verifies a Nebula CA certificate and adds it to the pool
|
||||
// Only the first pem encoded object will be consumed, any remaining bytes are returned.
|
||||
// Parsed certificates will be verified and must be a CA
|
||||
func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
|
||||
c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes)
|
||||
if err != nil {
|
||||
return pemBytes, err
|
||||
}
|
||||
|
||||
if !c.Details.IsCA {
|
||||
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA)
|
||||
}
|
||||
|
||||
if !c.CheckSignature(c.Details.PublicKey) {
|
||||
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned)
|
||||
}
|
||||
|
||||
sum, err := c.Sha256Sum()
|
||||
if err != nil {
|
||||
return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name)
|
||||
}
|
||||
|
||||
ncp.CAs[sum] = c
|
||||
if c.Expired(time.Now()) {
|
||||
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired)
|
||||
}
|
||||
|
||||
return pemBytes, nil
|
||||
}
|
||||
|
||||
// BlocklistFingerprint adds a cert fingerprint to the blocklist
|
||||
func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
|
||||
ncp.certBlocklist[f] = struct{}{}
|
||||
}
|
||||
|
||||
// ResetCertBlocklist removes all previously blocklisted cert fingerprints
|
||||
func (ncp *NebulaCAPool) ResetCertBlocklist() {
|
||||
ncp.certBlocklist = make(map[string]struct{})
|
||||
}
|
||||
|
||||
// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated
|
||||
// automatically if you manually change any fields in the NebulaCertificate.
|
||||
func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
|
||||
return ncp.isBlocklistedWithCache(c, false)
|
||||
}
|
||||
|
||||
// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
|
||||
func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool {
|
||||
h, err := c.sha256SumWithCache(useCache)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if _, ok := ncp.certBlocklist[h]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCAForCert attempts to return the signing certificate for the provided certificate.
|
||||
// No signature validation is performed
|
||||
func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) {
|
||||
if c.Details.Issuer == "" {
|
||||
return nil, fmt.Errorf("no issuer in certificate")
|
||||
}
|
||||
|
||||
signer, ok := ncp.CAs[c.Details.Issuer]
|
||||
if ok {
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("could not find ca for the certificate")
|
||||
}
|
||||
|
||||
// GetFingerprints returns an array of trusted CA fingerprints
|
||||
func (ncp *NebulaCAPool) GetFingerprints() []string {
|
||||
fp := make([]string, len(ncp.CAs))
|
||||
|
||||
i := 0
|
||||
for k := range ncp.CAs {
|
||||
fp[i] = k
|
||||
i++
|
||||
}
|
||||
|
||||
return fp
|
||||
}
|
||||
296
cert/ca_pool.go
Normal file
296
cert/ca_pool.go
Normal file
@ -0,0 +1,296 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CAPool struct {
|
||||
CAs map[string]*CachedCertificate
|
||||
certBlocklist map[string]struct{}
|
||||
}
|
||||
|
||||
// NewCAPool creates an empty CAPool
|
||||
func NewCAPool() *CAPool {
|
||||
ca := CAPool{
|
||||
CAs: make(map[string]*CachedCertificate),
|
||||
certBlocklist: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
return &ca
|
||||
}
|
||||
|
||||
// NewCAPoolFromPEM will create a new CA pool from the provided
|
||||
// input bytes, which must be a PEM-encoded set of nebula certificates.
|
||||
// If the pool contains any expired certificates, an ErrExpired will be
|
||||
// returned along with the pool. The caller must handle any such errors.
|
||||
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
||||
pool := NewCAPool()
|
||||
var err error
|
||||
var expired bool
|
||||
for {
|
||||
caPEMs, err = pool.AddCAFromPEM(caPEMs)
|
||||
if errors.Is(err, ErrExpired) {
|
||||
expired = true
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
return pool, ErrExpired
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool.
|
||||
// Only the first pem encoded object will be consumed, any remaining bytes are returned.
|
||||
// Parsed certificates will be verified and must be a CA
|
||||
func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) {
|
||||
c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes)
|
||||
if err != nil {
|
||||
return pemBytes, err
|
||||
}
|
||||
|
||||
err = ncp.AddCA(c)
|
||||
if err != nil {
|
||||
return pemBytes, err
|
||||
}
|
||||
|
||||
return pemBytes, nil
|
||||
}
|
||||
|
||||
// AddCA verifies a Nebula CA certificate and adds it to the pool.
|
||||
func (ncp *CAPool) AddCA(c Certificate) error {
|
||||
if !c.IsCA() {
|
||||
return fmt.Errorf("%s: %w", c.Name(), ErrNotCA)
|
||||
}
|
||||
|
||||
if !c.CheckSignature(c.PublicKey()) {
|
||||
return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned)
|
||||
}
|
||||
|
||||
sum, err := c.Fingerprint()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name())
|
||||
}
|
||||
|
||||
cc := &CachedCertificate{
|
||||
Certificate: c,
|
||||
Fingerprint: sum,
|
||||
InvertedGroups: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
for _, g := range c.Groups() {
|
||||
cc.InvertedGroups[g] = struct{}{}
|
||||
}
|
||||
|
||||
ncp.CAs[sum] = cc
|
||||
|
||||
if c.Expired(time.Now()) {
|
||||
return fmt.Errorf("%s: %w", c.Name(), ErrExpired)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BlocklistFingerprint adds a cert fingerprint to the blocklist
|
||||
func (ncp *CAPool) BlocklistFingerprint(f string) {
|
||||
ncp.certBlocklist[f] = struct{}{}
|
||||
}
|
||||
|
||||
// ResetCertBlocklist removes all previously blocklisted cert fingerprints
|
||||
func (ncp *CAPool) ResetCertBlocklist() {
|
||||
ncp.certBlocklist = make(map[string]struct{})
|
||||
}
|
||||
|
||||
// IsBlocklisted tests the provided fingerprint against the pools blocklist.
|
||||
// Returns true if the fingerprint is blocked.
|
||||
func (ncp *CAPool) IsBlocklisted(fingerprint string) bool {
|
||||
if _, ok := ncp.certBlocklist[fingerprint]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool.
|
||||
// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts
|
||||
// to increase performance.
|
||||
func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) {
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("no certificate")
|
||||
}
|
||||
fp, err := c.Fingerprint()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err)
|
||||
}
|
||||
|
||||
signer, err := ncp.verify(c, now, fp, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc := CachedCertificate{
|
||||
Certificate: c,
|
||||
InvertedGroups: make(map[string]struct{}),
|
||||
Fingerprint: fp,
|
||||
signerFingerprint: signer.Fingerprint,
|
||||
}
|
||||
|
||||
for _, g := range c.Groups() {
|
||||
cc.InvertedGroups[g] = struct{}{}
|
||||
}
|
||||
|
||||
return &cc, nil
|
||||
}
|
||||
|
||||
// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
|
||||
// is a cheaper operation to perform as a result.
|
||||
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
|
||||
_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) {
|
||||
if ncp.IsBlocklisted(certFp) {
|
||||
return nil, ErrBlockListed
|
||||
}
|
||||
|
||||
signer, err := ncp.GetCAForCert(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if signer.Certificate.Expired(now) {
|
||||
return nil, ErrRootExpired
|
||||
}
|
||||
|
||||
if c.Expired(now) {
|
||||
return nil, ErrExpired
|
||||
}
|
||||
|
||||
// If we are checking a cached certificate then we can bail early here
|
||||
// Either the root is no longer trusted or everything is fine
|
||||
if len(signerFp) > 0 {
|
||||
if signerFp != signer.Fingerprint {
|
||||
return nil, ErrFingerprintMismatch
|
||||
}
|
||||
return signer, nil
|
||||
}
|
||||
if !c.CheckSignature(signer.Certificate.PublicKey()) {
|
||||
return nil, ErrSignatureMismatch
|
||||
}
|
||||
|
||||
err = CheckCAConstraints(signer.Certificate, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
// GetCAForCert attempts to return the signing certificate for the provided certificate.
|
||||
// No signature validation is performed
|
||||
func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
|
||||
issuer := c.Issuer()
|
||||
if issuer == "" {
|
||||
return nil, fmt.Errorf("no issuer in certificate")
|
||||
}
|
||||
|
||||
signer, ok := ncp.CAs[issuer]
|
||||
if ok {
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
return nil, ErrCaNotFound
|
||||
}
|
||||
|
||||
// GetFingerprints returns an array of trusted CA fingerprints
|
||||
func (ncp *CAPool) GetFingerprints() []string {
|
||||
fp := make([]string, len(ncp.CAs))
|
||||
|
||||
i := 0
|
||||
for k := range ncp.CAs {
|
||||
fp[i] = k
|
||||
i++
|
||||
}
|
||||
|
||||
return fp
|
||||
}
|
||||
|
||||
// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate.
|
||||
func CheckCAConstraints(signer Certificate, sub Certificate) error {
|
||||
return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks())
|
||||
}
|
||||
|
||||
// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested.
|
||||
func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error {
|
||||
// Make sure this cert isn't valid after the root
|
||||
if notAfter.After(signer.NotAfter()) {
|
||||
return fmt.Errorf("certificate expires after signing certificate")
|
||||
}
|
||||
|
||||
// Make sure this cert wasn't valid before the root
|
||||
if notBefore.Before(signer.NotBefore()) {
|
||||
return fmt.Errorf("certificate is valid before the signing certificate")
|
||||
}
|
||||
|
||||
// If the signer has a limited set of groups make sure the cert only contains a subset
|
||||
signerGroups := signer.Groups()
|
||||
if len(signerGroups) > 0 {
|
||||
for _, g := range groups {
|
||||
if !slices.Contains(signerGroups, g) {
|
||||
return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
|
||||
signingNetworks := signer.Networks()
|
||||
if len(signingNetworks) > 0 {
|
||||
for _, certNetwork := range networks {
|
||||
found := false
|
||||
for _, signingNetwork := range signingNetworks {
|
||||
if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
|
||||
signingUnsafeNetworks := signer.UnsafeNetworks()
|
||||
if len(signingUnsafeNetworks) > 0 {
|
||||
for _, certUnsafeNetwork := range unsafeNetworks {
|
||||
found := false
|
||||
for _, caNetwork := range signingUnsafeNetworks {
|
||||
if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
560
cert/ca_pool_test.go
Normal file
560
cert/ca_pool_test.go
Normal file
@ -0,0 +1,560 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewCAPoolFromBytes(t *testing.T) {
|
||||
noNewLines := `
|
||||
# Current provisional, Remove once everything moves over to the real root.
|
||||
-----BEGIN NEBULA CERTIFICATE-----
|
||||
Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
|
||||
PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
|
||||
2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
|
||||
-----END NEBULA CERTIFICATE-----
|
||||
# root-ca01
|
||||
-----BEGIN NEBULA CERTIFICATE-----
|
||||
CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
|
||||
BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
|
||||
rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
|
||||
-----END NEBULA CERTIFICATE-----
|
||||
`
|
||||
|
||||
withNewLines := `
|
||||
# Current provisional, Remove once everything moves over to the real root.
|
||||
|
||||
-----BEGIN NEBULA CERTIFICATE-----
|
||||
Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
|
||||
PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
|
||||
2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
|
||||
-----END NEBULA CERTIFICATE-----
|
||||
|
||||
# root-ca01
|
||||
|
||||
|
||||
-----BEGIN NEBULA CERTIFICATE-----
|
||||
CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
|
||||
BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
|
||||
rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
|
||||
-----END NEBULA CERTIFICATE-----
|
||||
|
||||
`
|
||||
|
||||
expired := `
|
||||
# expired certificate
|
||||
-----BEGIN NEBULA CERTIFICATE-----
|
||||
CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA
|
||||
7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8
|
||||
Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0=
|
||||
-----END NEBULA CERTIFICATE-----
|
||||
`
|
||||
|
||||
p256 := `
|
||||
# p256 certificate
|
||||
-----BEGIN NEBULA CERTIFICATE-----
|
||||
CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp
|
||||
k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
|
||||
+0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq
|
||||
75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA==
|
||||
-----END NEBULA CERTIFICATE-----
|
||||
`
|
||||
|
||||
rootCA := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "nebula root ca",
|
||||
},
|
||||
}
|
||||
|
||||
rootCA01 := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "nebula root ca 01",
|
||||
},
|
||||
}
|
||||
|
||||
rootCAP256 := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "nebula P256 test",
|
||||
},
|
||||
}
|
||||
|
||||
p, err := NewCAPoolFromPEM([]byte(noNewLines))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||
assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||
|
||||
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||
assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||
|
||||
// expired cert, no valid certs
|
||||
ppp, err := NewCAPoolFromPEM([]byte(expired))
|
||||
assert.Equal(t, ErrExpired, err)
|
||||
assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
|
||||
|
||||
// expired cert, with valid certs
|
||||
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
|
||||
assert.Equal(t, ErrExpired, err)
|
||||
assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||
assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||
assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
|
||||
assert.Len(t, pppp.CAs, 3)
|
||||
|
||||
ppppp, err := NewCAPoolFromPEM([]byte(p256))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
|
||||
assert.Len(t, ppppp.CAs, 1)
|
||||
}
|
||||
|
||||
func TestCertificateV1_Verify(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||
|
||||
caPool := NewCAPool()
|
||||
require.NoError(t, caPool.AddCA(ca))
|
||||
|
||||
f, err := c.Fingerprint()
|
||||
require.NoError(t, err)
|
||||
caPool.BlocklistFingerprint(f)
|
||||
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.EqualError(t, err, "certificate is in the block list")
|
||||
|
||||
caPool.ResetCertBlocklist()
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||
require.EqualError(t, err, "root certificate is expired")
|
||||
|
||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
})
|
||||
|
||||
// Test group assertion
|
||||
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool = NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
|
||||
})
|
||||
|
||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV1_VerifyP256(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||
|
||||
caPool := NewCAPool()
|
||||
require.NoError(t, caPool.AddCA(ca))
|
||||
|
||||
f, err := c.Fingerprint()
|
||||
require.NoError(t, err)
|
||||
caPool.BlocklistFingerprint(f)
|
||||
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.EqualError(t, err, "certificate is in the block list")
|
||||
|
||||
caPool.ResetCertBlocklist()
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||
require.EqualError(t, err, "root certificate is expired")
|
||||
|
||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
})
|
||||
|
||||
// Test group assertion
|
||||
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool = NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
|
||||
})
|
||||
|
||||
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV1_Verify_IPs(t *testing.T) {
|
||||
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
|
||||
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
|
||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool := NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
// ip is outside the network
|
||||
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
|
||||
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
|
||||
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is outside the network reversed order of above
|
||||
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
|
||||
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is within the network but mask is outside
|
||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
|
||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is within the network but mask is outside reversed order of above
|
||||
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
|
||||
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
})
|
||||
|
||||
// ip and mask are within the network
|
||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches
|
||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches reversed
|
||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches reversed with just 1
|
||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
||||
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
|
||||
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
|
||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool := NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
// ip is outside the network
|
||||
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
|
||||
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
|
||||
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is outside the network reversed order of above
|
||||
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
|
||||
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is within the network but mask is outside
|
||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
|
||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is within the network but mask is outside reversed order of above
|
||||
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
|
||||
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
|
||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
})
|
||||
|
||||
// ip and mask are within the network
|
||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches
|
||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches reversed
|
||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches reversed with just 1
|
||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV2_Verify(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||
|
||||
caPool := NewCAPool()
|
||||
require.NoError(t, caPool.AddCA(ca))
|
||||
|
||||
f, err := c.Fingerprint()
|
||||
require.NoError(t, err)
|
||||
caPool.BlocklistFingerprint(f)
|
||||
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.EqualError(t, err, "certificate is in the block list")
|
||||
|
||||
caPool.ResetCertBlocklist()
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||
require.EqualError(t, err, "root certificate is expired")
|
||||
|
||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
})
|
||||
|
||||
// Test group assertion
|
||||
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool = NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
|
||||
})
|
||||
|
||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV2_VerifyP256(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||
|
||||
caPool := NewCAPool()
|
||||
require.NoError(t, caPool.AddCA(ca))
|
||||
|
||||
f, err := c.Fingerprint()
|
||||
require.NoError(t, err)
|
||||
caPool.BlocklistFingerprint(f)
|
||||
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.EqualError(t, err, "certificate is in the block list")
|
||||
|
||||
caPool.ResetCertBlocklist()
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||
require.EqualError(t, err, "root certificate is expired")
|
||||
|
||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
})
|
||||
|
||||
// Test group assertion
|
||||
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool = NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
|
||||
})
|
||||
|
||||
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV2_Verify_IPs(t *testing.T) {
|
||||
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
|
||||
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
|
||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool := NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
// ip is outside the network
|
||||
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
|
||||
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
|
||||
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is outside the network reversed order of above
|
||||
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
|
||||
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is within the network but mask is outside
|
||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
|
||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is within the network but mask is outside reversed order of above
|
||||
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
|
||||
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
})
|
||||
|
||||
// ip and mask are within the network
|
||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches
|
||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches reversed
|
||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches reversed with just 1
|
||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
||||
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
|
||||
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
|
||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool := NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
// ip is outside the network
|
||||
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
|
||||
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
|
||||
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is outside the network reversed order of above
|
||||
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
|
||||
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is within the network but mask is outside
|
||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
|
||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
})
|
||||
|
||||
// ip is within the network but mask is outside reversed order of above
|
||||
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
|
||||
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
|
||||
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
|
||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
})
|
||||
|
||||
// ip and mask are within the network
|
||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches
|
||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches reversed
|
||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exact matches reversed with just 1
|
||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
||||
require.NoError(t, err)
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
1159
cert/cert.go
1159
cert/cert.go
File diff suppressed because it is too large
Load Diff
1230
cert/cert_test.go
1230
cert/cert_test.go
File diff suppressed because it is too large
Load Diff
489
cert/cert_v1.go
Normal file
489
cert/cert_v1.go
Normal file
@ -0,0 +1,489 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdh"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const publicKeyLen = 32
|
||||
|
||||
type certificateV1 struct {
|
||||
details detailsV1
|
||||
signature []byte
|
||||
}
|
||||
|
||||
type detailsV1 struct {
|
||||
name string
|
||||
networks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
groups []string
|
||||
notBefore time.Time
|
||||
notAfter time.Time
|
||||
publicKey []byte
|
||||
isCA bool
|
||||
issuer string
|
||||
|
||||
curve Curve
|
||||
}
|
||||
|
||||
type m map[string]interface{}
|
||||
|
||||
func (c *certificateV1) Version() Version {
|
||||
return Version1
|
||||
}
|
||||
|
||||
func (c *certificateV1) Curve() Curve {
|
||||
return c.details.curve
|
||||
}
|
||||
|
||||
func (c *certificateV1) Groups() []string {
|
||||
return c.details.groups
|
||||
}
|
||||
|
||||
func (c *certificateV1) IsCA() bool {
|
||||
return c.details.isCA
|
||||
}
|
||||
|
||||
func (c *certificateV1) Issuer() string {
|
||||
return c.details.issuer
|
||||
}
|
||||
|
||||
func (c *certificateV1) Name() string {
|
||||
return c.details.name
|
||||
}
|
||||
|
||||
func (c *certificateV1) Networks() []netip.Prefix {
|
||||
return c.details.networks
|
||||
}
|
||||
|
||||
func (c *certificateV1) NotAfter() time.Time {
|
||||
return c.details.notAfter
|
||||
}
|
||||
|
||||
func (c *certificateV1) NotBefore() time.Time {
|
||||
return c.details.notBefore
|
||||
}
|
||||
|
||||
func (c *certificateV1) PublicKey() []byte {
|
||||
return c.details.publicKey
|
||||
}
|
||||
|
||||
func (c *certificateV1) Signature() []byte {
|
||||
return c.signature
|
||||
}
|
||||
|
||||
func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
|
||||
return c.details.unsafeNetworks
|
||||
}
|
||||
|
||||
func (c *certificateV1) Fingerprint() (string, error) {
|
||||
b, err := c.Marshal()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(b)
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func (c *certificateV1) CheckSignature(key []byte) bool {
|
||||
b, err := proto.Marshal(c.getRawDetails())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
switch c.details.curve {
|
||||
case Curve_CURVE25519:
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||
hashed := sha256.Sum256(b)
|
||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certificateV1) Expired(t time.Time) bool {
|
||||
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
|
||||
}
|
||||
|
||||
func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
|
||||
if curve != c.details.curve {
|
||||
return fmt.Errorf("curve in cert and private key supplied don't match")
|
||||
}
|
||||
if c.details.isCA {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
// the call to PublicKey below will panic slice bounds out of range otherwise
|
||||
if len(key) != ed25519.PrivateKeySize {
|
||||
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
|
||||
}
|
||||
|
||||
if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
|
||||
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||
}
|
||||
case Curve_P256:
|
||||
privkey, err := ecdh.P256().NewPrivateKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse private key as P256: %w", err)
|
||||
}
|
||||
pub := privkey.PublicKey().Bytes()
|
||||
if !bytes.Equal(pub, c.details.publicKey) {
|
||||
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("invalid curve: %s", curve)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var pub []byte
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
var err error
|
||||
pub, err = curve25519.X25519(key, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case Curve_P256:
|
||||
privkey, err := ecdh.P256().NewPrivateKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pub = privkey.PublicKey().Bytes()
|
||||
default:
|
||||
return fmt.Errorf("invalid curve: %s", curve)
|
||||
}
|
||||
if !bytes.Equal(pub, c.details.publicKey) {
|
||||
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRawDetails marshals the raw details into protobuf ready struct
|
||||
func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
|
||||
rd := &RawNebulaCertificateDetails{
|
||||
Name: c.details.name,
|
||||
Groups: c.details.groups,
|
||||
NotBefore: c.details.notBefore.Unix(),
|
||||
NotAfter: c.details.notAfter.Unix(),
|
||||
PublicKey: make([]byte, len(c.details.publicKey)),
|
||||
IsCA: c.details.isCA,
|
||||
Curve: c.details.curve,
|
||||
}
|
||||
|
||||
for _, ipNet := range c.details.networks {
|
||||
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
|
||||
rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
|
||||
}
|
||||
|
||||
for _, ipNet := range c.details.unsafeNetworks {
|
||||
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
|
||||
rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
|
||||
}
|
||||
|
||||
copy(rd.PublicKey, c.details.publicKey[:])
|
||||
|
||||
// I know, this is terrible
|
||||
rd.Issuer, _ = hex.DecodeString(c.details.issuer)
|
||||
|
||||
return rd
|
||||
}
|
||||
|
||||
func (c *certificateV1) String() string {
|
||||
b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
|
||||
if err != nil {
|
||||
return fmt.Sprintf("<error marshalling certificate: %v>", err)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
|
||||
pubKey := c.details.publicKey
|
||||
c.details.publicKey = nil
|
||||
rawCertNoKey, err := c.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.details.publicKey = pubKey
|
||||
return rawCertNoKey, nil
|
||||
}
|
||||
|
||||
func (c *certificateV1) Marshal() ([]byte, error) {
|
||||
rc := RawNebulaCertificate{
|
||||
Details: c.getRawDetails(),
|
||||
Signature: c.signature,
|
||||
}
|
||||
|
||||
return proto.Marshal(&rc)
|
||||
}
|
||||
|
||||
func (c *certificateV1) MarshalPEM() ([]byte, error) {
|
||||
b, err := c.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
|
||||
}
|
||||
|
||||
func (c *certificateV1) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(c.marshalJSON())
|
||||
}
|
||||
|
||||
func (c *certificateV1) marshalJSON() m {
|
||||
fp, _ := c.Fingerprint()
|
||||
return m{
|
||||
"version": Version1,
|
||||
"details": m{
|
||||
"name": c.details.name,
|
||||
"networks": c.details.networks,
|
||||
"unsafeNetworks": c.details.unsafeNetworks,
|
||||
"groups": c.details.groups,
|
||||
"notBefore": c.details.notBefore,
|
||||
"notAfter": c.details.notAfter,
|
||||
"publicKey": fmt.Sprintf("%x", c.details.publicKey),
|
||||
"isCa": c.details.isCA,
|
||||
"issuer": c.details.issuer,
|
||||
"curve": c.details.curve.String(),
|
||||
},
|
||||
"fingerprint": fp,
|
||||
"signature": fmt.Sprintf("%x", c.Signature()),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certificateV1) Copy() Certificate {
|
||||
nc := &certificateV1{
|
||||
details: detailsV1{
|
||||
name: c.details.name,
|
||||
notBefore: c.details.notBefore,
|
||||
notAfter: c.details.notAfter,
|
||||
publicKey: make([]byte, len(c.details.publicKey)),
|
||||
isCA: c.details.isCA,
|
||||
issuer: c.details.issuer,
|
||||
curve: c.details.curve,
|
||||
},
|
||||
signature: make([]byte, len(c.signature)),
|
||||
}
|
||||
|
||||
if c.details.groups != nil {
|
||||
nc.details.groups = make([]string, len(c.details.groups))
|
||||
copy(nc.details.groups, c.details.groups)
|
||||
}
|
||||
|
||||
if c.details.networks != nil {
|
||||
nc.details.networks = make([]netip.Prefix, len(c.details.networks))
|
||||
copy(nc.details.networks, c.details.networks)
|
||||
}
|
||||
|
||||
if c.details.unsafeNetworks != nil {
|
||||
nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
|
||||
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
|
||||
}
|
||||
|
||||
copy(nc.signature, c.signature)
|
||||
copy(nc.details.publicKey, c.details.publicKey)
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
|
||||
c.details = detailsV1{
|
||||
name: t.Name,
|
||||
networks: t.Networks,
|
||||
unsafeNetworks: t.UnsafeNetworks,
|
||||
groups: t.Groups,
|
||||
notBefore: t.NotBefore,
|
||||
notAfter: t.NotAfter,
|
||||
publicKey: t.PublicKey,
|
||||
isCA: t.IsCA,
|
||||
curve: t.Curve,
|
||||
issuer: t.issuer,
|
||||
}
|
||||
|
||||
return c.validate()
|
||||
}
|
||||
|
||||
func (c *certificateV1) validate() error {
|
||||
// Empty names are allowed
|
||||
|
||||
if len(c.details.publicKey) == 0 {
|
||||
return ErrInvalidPublicKey
|
||||
}
|
||||
|
||||
// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
|
||||
// Continue to allow this behavior
|
||||
if !c.details.isCA && len(c.details.networks) == 0 {
|
||||
return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
|
||||
}
|
||||
|
||||
for _, network := range c.details.networks {
|
||||
if !network.IsValid() || !network.Addr().IsValid() {
|
||||
return NewErrInvalidCertificateProperties("invalid network: %s", network)
|
||||
}
|
||||
|
||||
if network.Addr().Is6() {
|
||||
return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
|
||||
}
|
||||
|
||||
if network.Addr().IsUnspecified() {
|
||||
return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
|
||||
}
|
||||
|
||||
if network.Addr().Zone() != "" {
|
||||
return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
|
||||
}
|
||||
}
|
||||
|
||||
for _, network := range c.details.unsafeNetworks {
|
||||
if !network.IsValid() || !network.Addr().IsValid() {
|
||||
return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
|
||||
}
|
||||
|
||||
if network.Addr().Is6() {
|
||||
return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
|
||||
}
|
||||
|
||||
if network.Addr().Zone() != "" {
|
||||
return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
|
||||
}
|
||||
}
|
||||
|
||||
// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
|
||||
// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
|
||||
// unsafe networks would result in a different signature.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *certificateV1) marshalForSigning() ([]byte, error) {
|
||||
b, err := proto.Marshal(c.getRawDetails())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (c *certificateV1) setSignature(b []byte) error {
|
||||
if len(b) == 0 {
|
||||
return ErrEmptySignature
|
||||
}
|
||||
c.signature = b
|
||||
return nil
|
||||
}
|
||||
|
||||
// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
|
||||
// if the publicKey is provided here then it is not required to be present in `b`
|
||||
func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, fmt.Errorf("nil byte array")
|
||||
}
|
||||
var rc RawNebulaCertificate
|
||||
err := proto.Unmarshal(b, &rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rc.Details == nil {
|
||||
return nil, fmt.Errorf("encoded Details was nil")
|
||||
}
|
||||
|
||||
if len(rc.Details.Ips)%2 != 0 {
|
||||
return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
|
||||
}
|
||||
|
||||
if len(rc.Details.Subnets)%2 != 0 {
|
||||
return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
|
||||
}
|
||||
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
name: rc.Details.Name,
|
||||
groups: make([]string, len(rc.Details.Groups)),
|
||||
networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
|
||||
unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
|
||||
notBefore: time.Unix(rc.Details.NotBefore, 0),
|
||||
notAfter: time.Unix(rc.Details.NotAfter, 0),
|
||||
publicKey: make([]byte, len(rc.Details.PublicKey)),
|
||||
isCA: rc.Details.IsCA,
|
||||
curve: rc.Details.Curve,
|
||||
},
|
||||
signature: make([]byte, len(rc.Signature)),
|
||||
}
|
||||
|
||||
copy(nc.signature, rc.Signature)
|
||||
copy(nc.details.groups, rc.Details.Groups)
|
||||
nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
|
||||
|
||||
if len(publicKey) > 0 {
|
||||
nc.details.publicKey = publicKey
|
||||
}
|
||||
|
||||
copy(nc.details.publicKey, rc.Details.PublicKey)
|
||||
|
||||
var ip netip.Addr
|
||||
for i, rawIp := range rc.Details.Ips {
|
||||
if i%2 == 0 {
|
||||
ip = int2addr(rawIp)
|
||||
} else {
|
||||
ones, _ := net.IPMask(int2ip(rawIp)).Size()
|
||||
nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
|
||||
}
|
||||
}
|
||||
|
||||
for i, rawIp := range rc.Details.Subnets {
|
||||
if i%2 == 0 {
|
||||
ip = int2addr(rawIp)
|
||||
} else {
|
||||
ones, _ := net.IPMask(int2ip(rawIp)).Size()
|
||||
nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
|
||||
}
|
||||
}
|
||||
|
||||
err = nc.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &nc, nil
|
||||
}
|
||||
|
||||
func ip2int(ip []byte) uint32 {
|
||||
if len(ip) == 16 {
|
||||
return binary.BigEndian.Uint32(ip[12:16])
|
||||
}
|
||||
return binary.BigEndian.Uint32(ip)
|
||||
}
|
||||
|
||||
func int2ip(nn uint32) net.IP {
|
||||
ip := make(net.IP, net.IPv4len)
|
||||
binary.BigEndian.PutUint32(ip, nn)
|
||||
return ip
|
||||
}
|
||||
|
||||
func addr2int(addr netip.Addr) uint32 {
|
||||
b := addr.Unmap().As4()
|
||||
return binary.BigEndian.Uint32(b[:])
|
||||
}
|
||||
|
||||
func int2addr(nn uint32) netip.Addr {
|
||||
ip := [4]byte{}
|
||||
binary.BigEndian.PutUint32(ip[:], nn)
|
||||
return netip.AddrFrom4(ip).Unmap()
|
||||
}
|
||||
@ -1,8 +1,8 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.30.0
|
||||
// protoc-gen-go v1.34.2
|
||||
// protoc v3.21.5
|
||||
// source: cert.proto
|
||||
// source: cert_v1.proto
|
||||
|
||||
package cert
|
||||
|
||||
@ -50,11 +50,11 @@ func (x Curve) String() string {
|
||||
}
|
||||
|
||||
func (Curve) Descriptor() protoreflect.EnumDescriptor {
|
||||
return file_cert_proto_enumTypes[0].Descriptor()
|
||||
return file_cert_v1_proto_enumTypes[0].Descriptor()
|
||||
}
|
||||
|
||||
func (Curve) Type() protoreflect.EnumType {
|
||||
return &file_cert_proto_enumTypes[0]
|
||||
return &file_cert_v1_proto_enumTypes[0]
|
||||
}
|
||||
|
||||
func (x Curve) Number() protoreflect.EnumNumber {
|
||||
@ -63,7 +63,7 @@ func (x Curve) Number() protoreflect.EnumNumber {
|
||||
|
||||
// Deprecated: Use Curve.Descriptor instead.
|
||||
func (Curve) EnumDescriptor() ([]byte, []int) {
|
||||
return file_cert_proto_rawDescGZIP(), []int{0}
|
||||
return file_cert_v1_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
type RawNebulaCertificate struct {
|
||||
@ -78,7 +78,7 @@ type RawNebulaCertificate struct {
|
||||
func (x *RawNebulaCertificate) Reset() {
|
||||
*x = RawNebulaCertificate{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cert_proto_msgTypes[0]
|
||||
mi := &file_cert_v1_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@ -91,7 +91,7 @@ func (x *RawNebulaCertificate) String() string {
|
||||
func (*RawNebulaCertificate) ProtoMessage() {}
|
||||
|
||||
func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cert_proto_msgTypes[0]
|
||||
mi := &file_cert_v1_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@ -104,7 +104,7 @@ func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead.
|
||||
func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
|
||||
return file_cert_proto_rawDescGZIP(), []int{0}
|
||||
return file_cert_v1_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
|
||||
@ -143,7 +143,7 @@ type RawNebulaCertificateDetails struct {
|
||||
func (x *RawNebulaCertificateDetails) Reset() {
|
||||
*x = RawNebulaCertificateDetails{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cert_proto_msgTypes[1]
|
||||
mi := &file_cert_v1_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@ -156,7 +156,7 @@ func (x *RawNebulaCertificateDetails) String() string {
|
||||
func (*RawNebulaCertificateDetails) ProtoMessage() {}
|
||||
|
||||
func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cert_proto_msgTypes[1]
|
||||
mi := &file_cert_v1_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@ -169,7 +169,7 @@ func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead.
|
||||
func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
|
||||
return file_cert_proto_rawDescGZIP(), []int{1}
|
||||
return file_cert_v1_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *RawNebulaCertificateDetails) GetName() string {
|
||||
@ -254,7 +254,7 @@ type RawNebulaEncryptedData struct {
|
||||
func (x *RawNebulaEncryptedData) Reset() {
|
||||
*x = RawNebulaEncryptedData{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cert_proto_msgTypes[2]
|
||||
mi := &file_cert_v1_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@ -267,7 +267,7 @@ func (x *RawNebulaEncryptedData) String() string {
|
||||
func (*RawNebulaEncryptedData) ProtoMessage() {}
|
||||
|
||||
func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cert_proto_msgTypes[2]
|
||||
mi := &file_cert_v1_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@ -280,7 +280,7 @@ func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead.
|
||||
func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) {
|
||||
return file_cert_proto_rawDescGZIP(), []int{2}
|
||||
return file_cert_v1_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata {
|
||||
@ -309,7 +309,7 @@ type RawNebulaEncryptionMetadata struct {
|
||||
func (x *RawNebulaEncryptionMetadata) Reset() {
|
||||
*x = RawNebulaEncryptionMetadata{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cert_proto_msgTypes[3]
|
||||
mi := &file_cert_v1_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@ -322,7 +322,7 @@ func (x *RawNebulaEncryptionMetadata) String() string {
|
||||
func (*RawNebulaEncryptionMetadata) ProtoMessage() {}
|
||||
|
||||
func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cert_proto_msgTypes[3]
|
||||
mi := &file_cert_v1_proto_msgTypes[3]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@ -335,7 +335,7 @@ func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead.
|
||||
func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) {
|
||||
return file_cert_proto_rawDescGZIP(), []int{3}
|
||||
return file_cert_v1_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string {
|
||||
@ -367,7 +367,7 @@ type RawNebulaArgon2Parameters struct {
|
||||
func (x *RawNebulaArgon2Parameters) Reset() {
|
||||
*x = RawNebulaArgon2Parameters{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cert_proto_msgTypes[4]
|
||||
mi := &file_cert_v1_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@ -380,7 +380,7 @@ func (x *RawNebulaArgon2Parameters) String() string {
|
||||
func (*RawNebulaArgon2Parameters) ProtoMessage() {}
|
||||
|
||||
func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cert_proto_msgTypes[4]
|
||||
mi := &file_cert_v1_proto_msgTypes[4]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@ -393,7 +393,7 @@ func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead.
|
||||
func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) {
|
||||
return file_cert_proto_rawDescGZIP(), []int{4}
|
||||
return file_cert_v1_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
func (x *RawNebulaArgon2Parameters) GetVersion() int32 {
|
||||
@ -431,87 +431,87 @@ func (x *RawNebulaArgon2Parameters) GetSalt() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_cert_proto protoreflect.FileDescriptor
|
||||
var File_cert_v1_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_cert_proto_rawDesc = []byte{
|
||||
0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65,
|
||||
0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
|
||||
0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65,
|
||||
0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
|
||||
0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74,
|
||||
0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07,
|
||||
0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61,
|
||||
0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e,
|
||||
0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
|
||||
0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65,
|
||||
0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73,
|
||||
0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53,
|
||||
0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75,
|
||||
0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
|
||||
0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a,
|
||||
0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03,
|
||||
0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e,
|
||||
0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e,
|
||||
0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69,
|
||||
0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c,
|
||||
0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20,
|
||||
0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73,
|
||||
0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65,
|
||||
0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e,
|
||||
0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63,
|
||||
0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
|
||||
0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12,
|
||||
0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
|
||||
0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
|
||||
0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72,
|
||||
0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12,
|
||||
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
|
||||
0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65,
|
||||
0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61,
|
||||
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
|
||||
0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72,
|
||||
0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61,
|
||||
0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f,
|
||||
0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
|
||||
0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52,
|
||||
0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72,
|
||||
0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
|
||||
0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12,
|
||||
0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05,
|
||||
0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d,
|
||||
0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72,
|
||||
0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d,
|
||||
0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
|
||||
0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69,
|
||||
0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28,
|
||||
0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65,
|
||||
0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00,
|
||||
0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69,
|
||||
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71,
|
||||
0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x33,
|
||||
var file_cert_v1_proto_rawDesc = []byte{
|
||||
0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
|
||||
0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
|
||||
0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a,
|
||||
0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
|
||||
0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
|
||||
0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c,
|
||||
0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69,
|
||||
0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77,
|
||||
0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74,
|
||||
0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03,
|
||||
0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18,
|
||||
0x0a, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52,
|
||||
0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75,
|
||||
0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73,
|
||||
0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20,
|
||||
0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a,
|
||||
0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03,
|
||||
0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75,
|
||||
0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50,
|
||||
0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41,
|
||||
0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06,
|
||||
0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73,
|
||||
0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20,
|
||||
0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65,
|
||||
0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e,
|
||||
0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61,
|
||||
0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
|
||||
0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45,
|
||||
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74,
|
||||
0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
|
||||
0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74,
|
||||
0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65,
|
||||
0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
|
||||
0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
|
||||
0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
|
||||
0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c,
|
||||
0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e,
|
||||
0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28,
|
||||
0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
|
||||
0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
|
||||
0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65,
|
||||
0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
|
||||
0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
|
||||
0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06,
|
||||
0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65,
|
||||
0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
|
||||
0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c,
|
||||
0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74,
|
||||
0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72,
|
||||
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05,
|
||||
0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75,
|
||||
0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31,
|
||||
0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a,
|
||||
0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63,
|
||||
0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62,
|
||||
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_cert_proto_rawDescOnce sync.Once
|
||||
file_cert_proto_rawDescData = file_cert_proto_rawDesc
|
||||
file_cert_v1_proto_rawDescOnce sync.Once
|
||||
file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_cert_proto_rawDescGZIP() []byte {
|
||||
file_cert_proto_rawDescOnce.Do(func() {
|
||||
file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData)
|
||||
func file_cert_v1_proto_rawDescGZIP() []byte {
|
||||
file_cert_v1_proto_rawDescOnce.Do(func() {
|
||||
file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData)
|
||||
})
|
||||
return file_cert_proto_rawDescData
|
||||
return file_cert_v1_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_cert_proto_goTypes = []interface{}{
|
||||
var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_cert_v1_proto_goTypes = []any{
|
||||
(Curve)(0), // 0: cert.Curve
|
||||
(*RawNebulaCertificate)(nil), // 1: cert.RawNebulaCertificate
|
||||
(*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails
|
||||
@ -519,7 +519,7 @@ var file_cert_proto_goTypes = []interface{}{
|
||||
(*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata
|
||||
(*RawNebulaArgon2Parameters)(nil), // 5: cert.RawNebulaArgon2Parameters
|
||||
}
|
||||
var file_cert_proto_depIdxs = []int32{
|
||||
var file_cert_v1_proto_depIdxs = []int32{
|
||||
2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
|
||||
0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve
|
||||
4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
|
||||
@ -531,13 +531,13 @@ var file_cert_proto_depIdxs = []int32{
|
||||
0, // [0:4] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_cert_proto_init() }
|
||||
func file_cert_proto_init() {
|
||||
if File_cert_proto != nil {
|
||||
func init() { file_cert_v1_proto_init() }
|
||||
func file_cert_v1_proto_init() {
|
||||
if File_cert_v1_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*RawNebulaCertificate); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -549,7 +549,7 @@ func file_cert_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*RawNebulaCertificateDetails); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -561,7 +561,7 @@ func file_cert_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*RawNebulaEncryptedData); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -573,7 +573,7 @@ func file_cert_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*RawNebulaEncryptionMetadata); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -585,7 +585,7 @@ func file_cert_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*RawNebulaArgon2Parameters); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -602,19 +602,19 @@ func file_cert_proto_init() {
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_cert_proto_rawDesc,
|
||||
RawDescriptor: file_cert_v1_proto_rawDesc,
|
||||
NumEnums: 1,
|
||||
NumMessages: 5,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_cert_proto_goTypes,
|
||||
DependencyIndexes: file_cert_proto_depIdxs,
|
||||
EnumInfos: file_cert_proto_enumTypes,
|
||||
MessageInfos: file_cert_proto_msgTypes,
|
||||
GoTypes: file_cert_v1_proto_goTypes,
|
||||
DependencyIndexes: file_cert_v1_proto_depIdxs,
|
||||
EnumInfos: file_cert_v1_proto_enumTypes,
|
||||
MessageInfos: file_cert_v1_proto_msgTypes,
|
||||
}.Build()
|
||||
File_cert_proto = out.File
|
||||
file_cert_proto_rawDesc = nil
|
||||
file_cert_proto_goTypes = nil
|
||||
file_cert_proto_depIdxs = nil
|
||||
File_cert_v1_proto = out.File
|
||||
file_cert_v1_proto_rawDesc = nil
|
||||
file_cert_v1_proto_goTypes = nil
|
||||
file_cert_v1_proto_depIdxs = nil
|
||||
}
|
||||
218
cert/cert_v1_test.go
Normal file
218
cert/cert_v1_test.go
Normal file
@ -0,0 +1,218 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestCertificateV1_Marshal(t *testing.T) {
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||
},
|
||||
unsafeNetworks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||
},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
publicKey: pubKey,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
b, err := nc.Marshal()
|
||||
require.NoError(t, err)
|
||||
//t.Log("Cert size:", len(b))
|
||||
|
||||
nc2, err := unmarshalCertificateV1(b, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, Version1, nc.Version())
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
assert.Equal(t, nc.Signature(), nc2.Signature())
|
||||
assert.Equal(t, nc.Name(), nc2.Name())
|
||||
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
||||
assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
|
||||
assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
|
||||
assert.Equal(t, nc.IsCA(), nc2.IsCA())
|
||||
|
||||
assert.Equal(t, nc.Networks(), nc2.Networks())
|
||||
assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
|
||||
|
||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||
}
|
||||
|
||||
func TestCertificateV1_Expired(t *testing.T) {
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
|
||||
notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
|
||||
assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
|
||||
assert.False(t, nc.Expired(time.Now()))
|
||||
}
|
||||
|
||||
func TestCertificateV1_MarshalJSON(t *testing.T) {
|
||||
time.Local = time.UTC
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||
},
|
||||
unsafeNetworks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||
},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
|
||||
notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
|
||||
publicKey: pubKey,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
b, err := nc.MarshalJSON()
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(
|
||||
t,
|
||||
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
|
||||
string(b),
|
||||
)
|
||||
}
|
||||
|
||||
func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
||||
require.Error(t, err)
|
||||
|
||||
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, priv2 := X25519Keypair()
|
||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
||||
require.Error(t, err)
|
||||
|
||||
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
assert.Equal(t, Curve_P256, curve)
|
||||
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, priv2 := P256Keypair()
|
||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// Ensure that upgrading the protobuf library does not change how certificates
|
||||
// are marshalled, since this would break signature verification
|
||||
func TestMarshalingCertificateV1Consistency(t *testing.T) {
|
||||
before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC)
|
||||
after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||
},
|
||||
unsafeNetworks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||
},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
publicKey: pubKey,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
b, err := nc.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
|
||||
|
||||
b, err = proto.Marshal(nc.getRawDetails())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
||||
}
|
||||
|
||||
func TestCertificateV1_Copy(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||
cc := c.Copy()
|
||||
test.AssertDeepCopyEqual(t, c, cc)
|
||||
}
|
||||
|
||||
func TestUnmarshalCertificateV1(t *testing.T) {
|
||||
// Test that we don't panic with an invalid certificate (#332)
|
||||
data := []byte("\x98\x00\x00")
|
||||
_, err := unmarshalCertificateV1(data, nil)
|
||||
require.EqualError(t, err, "encoded Details was nil")
|
||||
}
|
||||
|
||||
func appendByteSlices(b ...[]byte) []byte {
|
||||
retSlice := []byte{}
|
||||
for _, v := range b {
|
||||
retSlice = append(retSlice, v...)
|
||||
}
|
||||
return retSlice
|
||||
}
|
||||
|
||||
func mustParsePrefixUnmapped(s string) netip.Prefix {
|
||||
prefix := netip.MustParsePrefix(s)
|
||||
return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
|
||||
}
|
||||
37
cert/cert_v2.asn1
Normal file
37
cert/cert_v2.asn1
Normal file
@ -0,0 +1,37 @@
|
||||
Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
|
||||
|
||||
Name ::= UTF8String (SIZE (1..253))
|
||||
Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
|
||||
Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
|
||||
Curve ::= ENUMERATED {
|
||||
curve25519 (0),
|
||||
p256 (1)
|
||||
}
|
||||
|
||||
-- The maximum size of a certificate must not exceed 65536 bytes
|
||||
Certificate ::= SEQUENCE {
|
||||
details OCTET STRING,
|
||||
curve Curve DEFAULT curve25519,
|
||||
publicKey OCTET STRING,
|
||||
-- signature(details + curve + publicKey) using the appropriate method for curve
|
||||
signature OCTET STRING
|
||||
}
|
||||
|
||||
Details ::= SEQUENCE {
|
||||
name Name,
|
||||
|
||||
-- At least 1 ipv4 or ipv6 address must be present if isCA is false
|
||||
networks SEQUENCE OF Network OPTIONAL,
|
||||
unsafeNetworks SEQUENCE OF Network OPTIONAL,
|
||||
groups SEQUENCE OF Name OPTIONAL,
|
||||
isCA BOOLEAN DEFAULT false,
|
||||
notBefore Time,
|
||||
notAfter Time,
|
||||
|
||||
-- issuer is only required if isCA is false, if isCA is true then it must not be present
|
||||
issuer OCTET STRING OPTIONAL,
|
||||
...
|
||||
-- New fields can be added below here
|
||||
}
|
||||
|
||||
END
|
||||
730
cert/cert_v2.go
Normal file
730
cert/cert_v2.go
Normal file
@ -0,0 +1,730 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdh"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"golang.org/x/crypto/cryptobyte/asn1"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
const (
|
||||
classConstructed = 0x20
|
||||
classContextSpecific = 0x80
|
||||
|
||||
TagCertDetails = 0 | classConstructed | classContextSpecific
|
||||
TagCertCurve = 1 | classContextSpecific
|
||||
TagCertPublicKey = 2 | classContextSpecific
|
||||
TagCertSignature = 3 | classContextSpecific
|
||||
|
||||
TagDetailsName = 0 | classContextSpecific
|
||||
TagDetailsNetworks = 1 | classConstructed | classContextSpecific
|
||||
TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
|
||||
TagDetailsGroups = 3 | classConstructed | classContextSpecific
|
||||
TagDetailsIsCA = 4 | classContextSpecific
|
||||
TagDetailsNotBefore = 5 | classContextSpecific
|
||||
TagDetailsNotAfter = 6 | classContextSpecific
|
||||
TagDetailsIssuer = 7 | classContextSpecific
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxCertificateSize is the maximum length a valid certificate can be
|
||||
MaxCertificateSize = 65536
|
||||
|
||||
// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
|
||||
MaxNameLength = 253
|
||||
|
||||
// MaxNetworkLength is the maximum length a network value can be.
|
||||
// 16 bytes for an ipv6 address + 1 byte for the prefix length
|
||||
MaxNetworkLength = 17
|
||||
)
|
||||
|
||||
type certificateV2 struct {
|
||||
details detailsV2
|
||||
|
||||
// RawDetails contains the entire asn.1 DER encoded Details struct
|
||||
// This is to benefit forwards compatibility in signature checking.
|
||||
// signature(RawDetails + Curve + PublicKey) == Signature
|
||||
rawDetails []byte
|
||||
curve Curve
|
||||
publicKey []byte
|
||||
signature []byte
|
||||
}
|
||||
|
||||
type detailsV2 struct {
|
||||
name string
|
||||
networks []netip.Prefix // MUST BE SORTED
|
||||
unsafeNetworks []netip.Prefix // MUST BE SORTED
|
||||
groups []string
|
||||
isCA bool
|
||||
notBefore time.Time
|
||||
notAfter time.Time
|
||||
issuer string
|
||||
}
|
||||
|
||||
func (c *certificateV2) Version() Version {
|
||||
return Version2
|
||||
}
|
||||
|
||||
func (c *certificateV2) Curve() Curve {
|
||||
return c.curve
|
||||
}
|
||||
|
||||
func (c *certificateV2) Groups() []string {
|
||||
return c.details.groups
|
||||
}
|
||||
|
||||
func (c *certificateV2) IsCA() bool {
|
||||
return c.details.isCA
|
||||
}
|
||||
|
||||
func (c *certificateV2) Issuer() string {
|
||||
return c.details.issuer
|
||||
}
|
||||
|
||||
func (c *certificateV2) Name() string {
|
||||
return c.details.name
|
||||
}
|
||||
|
||||
func (c *certificateV2) Networks() []netip.Prefix {
|
||||
return c.details.networks
|
||||
}
|
||||
|
||||
func (c *certificateV2) NotAfter() time.Time {
|
||||
return c.details.notAfter
|
||||
}
|
||||
|
||||
func (c *certificateV2) NotBefore() time.Time {
|
||||
return c.details.notBefore
|
||||
}
|
||||
|
||||
func (c *certificateV2) PublicKey() []byte {
|
||||
return c.publicKey
|
||||
}
|
||||
|
||||
func (c *certificateV2) Signature() []byte {
|
||||
return c.signature
|
||||
}
|
||||
|
||||
func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
|
||||
return c.details.unsafeNetworks
|
||||
}
|
||||
|
||||
func (c *certificateV2) Fingerprint() (string, error) {
|
||||
if len(c.rawDetails) == 0 {
|
||||
return "", ErrMissingDetails
|
||||
}
|
||||
|
||||
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature))
|
||||
copy(b, c.rawDetails)
|
||||
b[len(c.rawDetails)] = byte(c.curve)
|
||||
copy(b[len(c.rawDetails)+1:], c.publicKey)
|
||||
copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
|
||||
sum := sha256.Sum256(b)
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func (c *certificateV2) CheckSignature(key []byte) bool {
|
||||
if len(c.rawDetails) == 0 {
|
||||
return false
|
||||
}
|
||||
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
|
||||
copy(b, c.rawDetails)
|
||||
b[len(c.rawDetails)] = byte(c.curve)
|
||||
copy(b[len(c.rawDetails)+1:], c.publicKey)
|
||||
|
||||
switch c.curve {
|
||||
case Curve_CURVE25519:
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||
hashed := sha256.Sum256(b)
|
||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certificateV2) Expired(t time.Time) bool {
|
||||
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
|
||||
}
|
||||
|
||||
func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
|
||||
if curve != c.curve {
|
||||
return ErrPublicPrivateCurveMismatch
|
||||
}
|
||||
if c.details.isCA {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
// the call to PublicKey below will panic slice bounds out of range otherwise
|
||||
if len(key) != ed25519.PrivateKeySize {
|
||||
return ErrInvalidPrivateKey
|
||||
}
|
||||
|
||||
if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
|
||||
return ErrPublicPrivateKeyMismatch
|
||||
}
|
||||
case Curve_P256:
|
||||
privkey, err := ecdh.P256().NewPrivateKey(key)
|
||||
if err != nil {
|
||||
return ErrInvalidPrivateKey
|
||||
}
|
||||
pub := privkey.PublicKey().Bytes()
|
||||
if !bytes.Equal(pub, c.publicKey) {
|
||||
return ErrPublicPrivateKeyMismatch
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("invalid curve: %s", curve)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var pub []byte
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
var err error
|
||||
pub, err = curve25519.X25519(key, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
return ErrInvalidPrivateKey
|
||||
}
|
||||
case Curve_P256:
|
||||
privkey, err := ecdh.P256().NewPrivateKey(key)
|
||||
if err != nil {
|
||||
return ErrInvalidPrivateKey
|
||||
}
|
||||
pub = privkey.PublicKey().Bytes()
|
||||
default:
|
||||
return fmt.Errorf("invalid curve: %s", curve)
|
||||
}
|
||||
if !bytes.Equal(pub, c.publicKey) {
|
||||
return ErrPublicPrivateKeyMismatch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *certificateV2) String() string {
|
||||
mb, err := c.marshalJSON()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("<error marshalling certificate: %v>", err)
|
||||
}
|
||||
|
||||
b, err := json.MarshalIndent(mb, "", "\t")
|
||||
if err != nil {
|
||||
return fmt.Sprintf("<error marshalling certificate: %v>", err)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
|
||||
if c.rawDetails == nil {
|
||||
return nil, ErrEmptyRawDetails
|
||||
}
|
||||
var b cryptobyte.Builder
|
||||
// Outermost certificate
|
||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||
|
||||
// Add the cert details which is already marshalled
|
||||
b.AddBytes(c.rawDetails)
|
||||
|
||||
// Skipping the curve and public key since those come across in a different part of the handshake
|
||||
|
||||
// Add the signature
|
||||
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(c.signature)
|
||||
})
|
||||
})
|
||||
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (c *certificateV2) Marshal() ([]byte, error) {
|
||||
if c.rawDetails == nil {
|
||||
return nil, ErrEmptyRawDetails
|
||||
}
|
||||
var b cryptobyte.Builder
|
||||
// Outermost certificate
|
||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||
|
||||
// Add the cert details which is already marshalled
|
||||
b.AddBytes(c.rawDetails)
|
||||
|
||||
// Add the curve only if its not the default value
|
||||
if c.curve != Curve_CURVE25519 {
|
||||
b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
|
||||
b.AddBytes([]byte{byte(c.curve)})
|
||||
})
|
||||
}
|
||||
|
||||
// Add the public key if it is not empty
|
||||
if c.publicKey != nil {
|
||||
b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(c.publicKey)
|
||||
})
|
||||
}
|
||||
|
||||
// Add the signature
|
||||
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(c.signature)
|
||||
})
|
||||
})
|
||||
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (c *certificateV2) MarshalPEM() ([]byte, error) {
|
||||
b, err := c.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
|
||||
}
|
||||
|
||||
func (c *certificateV2) MarshalJSON() ([]byte, error) {
|
||||
b, err := c.marshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(b)
|
||||
}
|
||||
|
||||
func (c *certificateV2) marshalJSON() (m, error) {
|
||||
fp, err := c.Fingerprint()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m{
|
||||
"details": m{
|
||||
"name": c.details.name,
|
||||
"networks": c.details.networks,
|
||||
"unsafeNetworks": c.details.unsafeNetworks,
|
||||
"groups": c.details.groups,
|
||||
"notBefore": c.details.notBefore,
|
||||
"notAfter": c.details.notAfter,
|
||||
"isCa": c.details.isCA,
|
||||
"issuer": c.details.issuer,
|
||||
},
|
||||
"version": Version2,
|
||||
"publicKey": fmt.Sprintf("%x", c.publicKey),
|
||||
"curve": c.curve.String(),
|
||||
"fingerprint": fp,
|
||||
"signature": fmt.Sprintf("%x", c.Signature()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *certificateV2) Copy() Certificate {
|
||||
nc := &certificateV2{
|
||||
details: detailsV2{
|
||||
name: c.details.name,
|
||||
notBefore: c.details.notBefore,
|
||||
notAfter: c.details.notAfter,
|
||||
isCA: c.details.isCA,
|
||||
issuer: c.details.issuer,
|
||||
},
|
||||
curve: c.curve,
|
||||
publicKey: make([]byte, len(c.publicKey)),
|
||||
signature: make([]byte, len(c.signature)),
|
||||
rawDetails: make([]byte, len(c.rawDetails)),
|
||||
}
|
||||
|
||||
if c.details.groups != nil {
|
||||
nc.details.groups = make([]string, len(c.details.groups))
|
||||
copy(nc.details.groups, c.details.groups)
|
||||
}
|
||||
|
||||
if c.details.networks != nil {
|
||||
nc.details.networks = make([]netip.Prefix, len(c.details.networks))
|
||||
copy(nc.details.networks, c.details.networks)
|
||||
}
|
||||
|
||||
if c.details.unsafeNetworks != nil {
|
||||
nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
|
||||
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
|
||||
}
|
||||
|
||||
copy(nc.rawDetails, c.rawDetails)
|
||||
copy(nc.signature, c.signature)
|
||||
copy(nc.publicKey, c.publicKey)
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
|
||||
c.details = detailsV2{
|
||||
name: t.Name,
|
||||
networks: t.Networks,
|
||||
unsafeNetworks: t.UnsafeNetworks,
|
||||
groups: t.Groups,
|
||||
isCA: t.IsCA,
|
||||
notBefore: t.NotBefore,
|
||||
notAfter: t.NotAfter,
|
||||
issuer: t.issuer,
|
||||
}
|
||||
c.curve = t.Curve
|
||||
c.publicKey = t.PublicKey
|
||||
return c.validate()
|
||||
}
|
||||
|
||||
func (c *certificateV2) validate() error {
|
||||
// Empty names are allowed
|
||||
|
||||
if len(c.publicKey) == 0 {
|
||||
return ErrInvalidPublicKey
|
||||
}
|
||||
|
||||
if !c.details.isCA && len(c.details.networks) == 0 {
|
||||
return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
|
||||
}
|
||||
|
||||
hasV4Networks := false
|
||||
hasV6Networks := false
|
||||
for _, network := range c.details.networks {
|
||||
if !network.IsValid() || !network.Addr().IsValid() {
|
||||
return NewErrInvalidCertificateProperties("invalid network: %s", network)
|
||||
}
|
||||
|
||||
if network.Addr().IsUnspecified() {
|
||||
return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
|
||||
}
|
||||
|
||||
if network.Addr().Zone() != "" {
|
||||
return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
|
||||
}
|
||||
|
||||
if network.Addr().Is4In6() {
|
||||
return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
|
||||
}
|
||||
|
||||
hasV4Networks = hasV4Networks || network.Addr().Is4()
|
||||
hasV6Networks = hasV6Networks || network.Addr().Is6()
|
||||
}
|
||||
|
||||
slices.SortFunc(c.details.networks, comparePrefix)
|
||||
err := findDuplicatePrefix(c.details.networks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, network := range c.details.unsafeNetworks {
|
||||
if !network.IsValid() || !network.Addr().IsValid() {
|
||||
return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
|
||||
}
|
||||
|
||||
if network.Addr().Zone() != "" {
|
||||
return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
|
||||
}
|
||||
|
||||
if !c.details.isCA {
|
||||
if network.Addr().Is6() {
|
||||
if !hasV6Networks {
|
||||
return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
|
||||
}
|
||||
} else if network.Addr().Is4() {
|
||||
if !hasV4Networks {
|
||||
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
|
||||
err = findDuplicatePrefix(c.details.unsafeNetworks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *certificateV2) marshalForSigning() ([]byte, error) {
|
||||
d, err := c.details.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling certificate details failed: %w", err)
|
||||
}
|
||||
c.rawDetails = d
|
||||
|
||||
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
|
||||
copy(b, c.rawDetails)
|
||||
b[len(c.rawDetails)] = byte(c.curve)
|
||||
copy(b[len(c.rawDetails)+1:], c.publicKey)
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (c *certificateV2) setSignature(b []byte) error {
|
||||
if len(b) == 0 {
|
||||
return ErrEmptySignature
|
||||
}
|
||||
c.signature = b
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *detailsV2) Marshal() ([]byte, error) {
|
||||
var b cryptobyte.Builder
|
||||
var err error
|
||||
|
||||
// Details are a structure
|
||||
b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
|
||||
|
||||
// Add the name
|
||||
b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
|
||||
b.AddBytes([]byte(d.name))
|
||||
})
|
||||
|
||||
// Add the networks if any exist
|
||||
if len(d.networks) > 0 {
|
||||
b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
|
||||
for _, n := range d.networks {
|
||||
sb, innerErr := n.MarshalBinary()
|
||||
if innerErr != nil {
|
||||
// MarshalBinary never returns an error
|
||||
err = fmt.Errorf("unable to marshal network: %w", innerErr)
|
||||
return
|
||||
}
|
||||
b.AddASN1OctetString(sb)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Add the unsafe networks if any exist
|
||||
if len(d.unsafeNetworks) > 0 {
|
||||
b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
|
||||
for _, n := range d.unsafeNetworks {
|
||||
sb, innerErr := n.MarshalBinary()
|
||||
if innerErr != nil {
|
||||
// MarshalBinary never returns an error
|
||||
err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
|
||||
return
|
||||
}
|
||||
b.AddASN1OctetString(sb)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Add groups if any exist
|
||||
if len(d.groups) > 0 {
|
||||
b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
|
||||
for _, group := range d.groups {
|
||||
b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
|
||||
b.AddBytes([]byte(group))
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Add IsCA only if true
|
||||
if d.isCA {
|
||||
b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
|
||||
b.AddUint8(0xff)
|
||||
})
|
||||
}
|
||||
|
||||
// Add not before
|
||||
b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
|
||||
|
||||
// Add not after
|
||||
b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
|
||||
|
||||
// Add the issuer if present
|
||||
if d.issuer != "" {
|
||||
issuerBytes, innerErr := hex.DecodeString(d.issuer)
|
||||
if innerErr != nil {
|
||||
err = fmt.Errorf("failed to decode issuer: %w", innerErr)
|
||||
return
|
||||
}
|
||||
b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(issuerBytes)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
|
||||
l := len(b)
|
||||
if l == 0 || l > MaxCertificateSize {
|
||||
return nil, ErrBadFormat
|
||||
}
|
||||
|
||||
input := cryptobyte.String(b)
|
||||
// Open the envelope
|
||||
if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
|
||||
return nil, ErrBadFormat
|
||||
}
|
||||
|
||||
// Grab the cert details, we need to preserve the tag and length
|
||||
var rawDetails cryptobyte.String
|
||||
if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
|
||||
return nil, ErrBadFormat
|
||||
}
|
||||
|
||||
//Maybe grab the curve
|
||||
var rawCurve byte
|
||||
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
|
||||
return nil, ErrBadFormat
|
||||
}
|
||||
curve = Curve(rawCurve)
|
||||
|
||||
// Maybe grab the public key
|
||||
var rawPublicKey cryptobyte.String
|
||||
if len(publicKey) > 0 {
|
||||
rawPublicKey = publicKey
|
||||
} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
|
||||
return nil, ErrBadFormat
|
||||
}
|
||||
|
||||
if len(rawPublicKey) == 0 {
|
||||
return nil, ErrBadFormat
|
||||
}
|
||||
|
||||
// Grab the signature
|
||||
var rawSignature cryptobyte.String
|
||||
if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
|
||||
return nil, ErrBadFormat
|
||||
}
|
||||
|
||||
// Finally unmarshal the details
|
||||
details, err := unmarshalDetails(rawDetails)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &certificateV2{
|
||||
details: details,
|
||||
rawDetails: rawDetails,
|
||||
curve: curve,
|
||||
publicKey: rawPublicKey,
|
||||
signature: rawSignature,
|
||||
}
|
||||
|
||||
err = c.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
|
||||
// Open the envelope
|
||||
if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
// Read the name
|
||||
var name cryptobyte.String
|
||||
if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
// Read the network addresses
|
||||
var subString cryptobyte.String
|
||||
var found bool
|
||||
|
||||
if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
var networks []netip.Prefix
|
||||
var val cryptobyte.String
|
||||
if found {
|
||||
for !subString.Empty() {
|
||||
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
var n netip.Prefix
|
||||
if err := n.UnmarshalBinary(val); err != nil {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
networks = append(networks, n)
|
||||
}
|
||||
}
|
||||
|
||||
// Read out any unsafe networks
|
||||
if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
var unsafeNetworks []netip.Prefix
|
||||
if found {
|
||||
for !subString.Empty() {
|
||||
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
var n netip.Prefix
|
||||
if err := n.UnmarshalBinary(val); err != nil {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
unsafeNetworks = append(unsafeNetworks, n)
|
||||
}
|
||||
}
|
||||
|
||||
// Read out any groups
|
||||
if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
var groups []string
|
||||
if found {
|
||||
for !subString.Empty() {
|
||||
if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
groups = append(groups, string(val))
|
||||
}
|
||||
}
|
||||
|
||||
// Read out IsCA
|
||||
var isCa bool
|
||||
if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
// Read not before and not after
|
||||
var notBefore int64
|
||||
if !b.ReadASN1Int64WithTag(¬Before, TagDetailsNotBefore) {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
var notAfter int64
|
||||
if !b.ReadASN1Int64WithTag(¬After, TagDetailsNotAfter) {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
// Read issuer
|
||||
var issuer cryptobyte.String
|
||||
if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
|
||||
return detailsV2{}, ErrBadFormat
|
||||
}
|
||||
|
||||
return detailsV2{
|
||||
name: string(name),
|
||||
networks: networks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
groups: groups,
|
||||
isCA: isCa,
|
||||
notBefore: time.Unix(notBefore, 0),
|
||||
notAfter: time.Unix(notAfter, 0),
|
||||
issuer: hex.EncodeToString(issuer),
|
||||
}, nil
|
||||
}
|
||||
267
cert/cert_v2_test.go
Normal file
267
cert/cert_v2_test.go
Normal file
@ -0,0 +1,267 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCertificateV2_Marshal(t *testing.T) {
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||
},
|
||||
unsafeNetworks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||
},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcdef1234567890abcdef",
|
||||
},
|
||||
signature: []byte("1234567890abcdef1234567890abcdef"),
|
||||
publicKey: pubKey,
|
||||
}
|
||||
|
||||
db, err := nc.details.Marshal()
|
||||
require.NoError(t, err)
|
||||
nc.rawDetails = db
|
||||
|
||||
b, err := nc.Marshal()
|
||||
require.NoError(t, err)
|
||||
//t.Log("Cert size:", len(b))
|
||||
|
||||
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, Version2, nc.Version())
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
assert.Equal(t, nc.Signature(), nc2.Signature())
|
||||
assert.Equal(t, nc.Name(), nc2.Name())
|
||||
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
||||
assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
|
||||
assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
|
||||
assert.Equal(t, nc.IsCA(), nc2.IsCA())
|
||||
assert.Equal(t, nc.Issuer(), nc2.Issuer())
|
||||
|
||||
// unmarshalling will sort networks and unsafeNetworks, we need to do the same
|
||||
// but first make sure it fails
|
||||
assert.NotEqual(t, nc.Networks(), nc2.Networks())
|
||||
assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
|
||||
|
||||
slices.SortFunc(nc.details.networks, comparePrefix)
|
||||
slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
|
||||
|
||||
assert.Equal(t, nc.Networks(), nc2.Networks())
|
||||
assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
|
||||
|
||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||
}
|
||||
|
||||
func TestCertificateV2_Expired(t *testing.T) {
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
|
||||
notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
|
||||
assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
|
||||
assert.False(t, nc.Expired(time.Now()))
|
||||
}
|
||||
|
||||
func TestCertificateV2_MarshalJSON(t *testing.T) {
|
||||
time.Local = time.UTC
|
||||
pubKey := []byte("1234567890abcedf1234567890abcedf")
|
||||
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||
},
|
||||
unsafeNetworks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||
},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
|
||||
notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedf1234567890abcedf",
|
||||
},
|
||||
publicKey: pubKey,
|
||||
signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
|
||||
}
|
||||
|
||||
b, err := nc.MarshalJSON()
|
||||
require.ErrorIs(t, err, ErrMissingDetails)
|
||||
|
||||
rd, err := nc.details.Marshal()
|
||||
require.NoError(t, err)
|
||||
|
||||
nc.rawDetails = rd
|
||||
b, err = nc.MarshalJSON()
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(
|
||||
t,
|
||||
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
|
||||
string(b),
|
||||
)
|
||||
}
|
||||
|
||||
func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
|
||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||
|
||||
_, caKey2, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
||||
require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
||||
|
||||
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, priv2 := X25519Keypair()
|
||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||
require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
|
||||
|
||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
||||
require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
||||
|
||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
|
||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||
|
||||
ac, ok := c.(*certificateV2)
|
||||
require.True(t, ok)
|
||||
ac.curve = Curve(99)
|
||||
err = c.VerifyPrivateKey(Curve(99), priv2)
|
||||
require.EqualError(t, err, "invalid curve: 99")
|
||||
|
||||
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
|
||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||
|
||||
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
|
||||
|
||||
err = c.VerifyPrivateKey(Curve_P256, priv[:16])
|
||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||
|
||||
err = c.VerifyPrivateKey(Curve_P256, priv)
|
||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||
|
||||
aCa, ok := ca2.(*certificateV2)
|
||||
require.True(t, ok)
|
||||
aCa.curve = Curve(99)
|
||||
err = aCa.VerifyPrivateKey(Curve(99), priv2)
|
||||
require.EqualError(t, err, "invalid curve: 99")
|
||||
|
||||
}
|
||||
|
||||
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
||||
require.Error(t, err)
|
||||
|
||||
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
assert.Equal(t, Curve_P256, curve)
|
||||
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, priv2 := P256Keypair()
|
||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCertificateV2_Copy(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||
cc := c.Copy()
|
||||
test.AssertDeepCopyEqual(t, c, cc)
|
||||
}
|
||||
|
||||
func TestUnmarshalCertificateV2(t *testing.T) {
|
||||
data := []byte("\x98\x00\x00")
|
||||
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
|
||||
require.EqualError(t, err, "bad wire format")
|
||||
}
|
||||
|
||||
func TestCertificateV2_marshalForSigningStability(t *testing.T) {
|
||||
before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC)
|
||||
after := before.Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||
},
|
||||
unsafeNetworks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||
},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcdef1234567890abcdef",
|
||||
},
|
||||
signature: []byte("1234567890abcdef1234567890abcdef"),
|
||||
publicKey: pubKey,
|
||||
}
|
||||
|
||||
const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef"
|
||||
expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := nc.details.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedRawDetails, db)
|
||||
|
||||
expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
|
||||
b, err := nc.marshalForSigning()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedForSigning, b)
|
||||
}
|
||||
161
cert/crypto.go
161
cert/crypto.go
@ -3,14 +3,28 @@ package cert
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// KDF factors
|
||||
type NebulaEncryptedData struct {
|
||||
EncryptionMetadata NebulaEncryptionMetadata
|
||||
Ciphertext []byte
|
||||
}
|
||||
|
||||
type NebulaEncryptionMetadata struct {
|
||||
EncryptionAlgorithm string
|
||||
Argon2Parameters Argon2Parameters
|
||||
}
|
||||
|
||||
// Argon2Parameters KDF factors
|
||||
type Argon2Parameters struct {
|
||||
version rune
|
||||
Memory uint32 // KiB
|
||||
@ -19,7 +33,7 @@ type Argon2Parameters struct {
|
||||
salt []byte
|
||||
}
|
||||
|
||||
// Returns a new Argon2Parameters object with current version set
|
||||
// NewArgon2Parameters Returns a new Argon2Parameters object with current version set
|
||||
func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters {
|
||||
return &Argon2Parameters{
|
||||
version: argon2.Version,
|
||||
@ -141,3 +155,146 @@ func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) {
|
||||
|
||||
return blob[:nonceSize], blob[nonceSize:], nil
|
||||
}
|
||||
|
||||
// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
|
||||
func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
|
||||
ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, err = proto.Marshal(&RawNebulaEncryptedData{
|
||||
EncryptionMetadata: &RawNebulaEncryptionMetadata{
|
||||
EncryptionAlgorithm: "AES-256-GCM",
|
||||
Argon2Parameters: &RawNebulaArgon2Parameters{
|
||||
Version: kdfParams.version,
|
||||
Memory: kdfParams.Memory,
|
||||
Parallelism: uint32(kdfParams.Parallelism),
|
||||
Iterations: kdfParams.Iterations,
|
||||
Salt: kdfParams.salt,
|
||||
},
|
||||
},
|
||||
Ciphertext: ciphertext,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
|
||||
case Curve_P256:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid curve: %v", curve)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
|
||||
// protobuf-generated struct.
|
||||
func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, fmt.Errorf("nil byte array")
|
||||
}
|
||||
var rned RawNebulaEncryptedData
|
||||
err := proto.Unmarshal(b, &rned)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rned.EncryptionMetadata == nil {
|
||||
return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
|
||||
}
|
||||
|
||||
if rned.EncryptionMetadata.Argon2Parameters == nil {
|
||||
return nil, fmt.Errorf("encoded Argon2Parameters was nil")
|
||||
}
|
||||
|
||||
params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ned := NebulaEncryptedData{
|
||||
EncryptionMetadata: NebulaEncryptionMetadata{
|
||||
EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
|
||||
Argon2Parameters: *params,
|
||||
},
|
||||
Ciphertext: rned.Ciphertext,
|
||||
}
|
||||
|
||||
return &ned, nil
|
||||
}
|
||||
|
||||
func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
|
||||
if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
|
||||
return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
|
||||
}
|
||||
if params.Memory <= 0 || params.Memory > math.MaxUint32 {
|
||||
return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
|
||||
}
|
||||
if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
|
||||
return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
|
||||
}
|
||||
if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
|
||||
return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
|
||||
}
|
||||
|
||||
return &Argon2Parameters{
|
||||
version: params.Version,
|
||||
Memory: params.Memory,
|
||||
Parallelism: uint8(params.Parallelism),
|
||||
Iterations: params.Iterations,
|
||||
salt: params.Salt,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
|
||||
// the given passphrase, returning any other bytes b or an error on failure
|
||||
func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
|
||||
var curve Curve
|
||||
|
||||
k, r := pem.Decode(b)
|
||||
if k == nil {
|
||||
return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
|
||||
}
|
||||
|
||||
switch k.Type {
|
||||
case EncryptedEd25519PrivateKeyBanner:
|
||||
curve = Curve_CURVE25519
|
||||
case EncryptedECDSAP256PrivateKeyBanner:
|
||||
curve = Curve_P256
|
||||
default:
|
||||
return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
|
||||
}
|
||||
|
||||
ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
|
||||
if err != nil {
|
||||
return curve, nil, r, err
|
||||
}
|
||||
|
||||
var bytes []byte
|
||||
switch ned.EncryptionMetadata.EncryptionAlgorithm {
|
||||
case "AES-256-GCM":
|
||||
bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
|
||||
if err != nil {
|
||||
return curve, nil, r, err
|
||||
}
|
||||
default:
|
||||
return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
|
||||
}
|
||||
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
if len(bytes) != ed25519.PrivateKeySize {
|
||||
return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
|
||||
}
|
||||
case Curve_P256:
|
||||
if len(bytes) != 32 {
|
||||
return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
|
||||
}
|
||||
}
|
||||
|
||||
return curve, bytes, r, nil
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
@ -23,3 +24,90 @@ func TestNewArgon2Parameters(t *testing.T) {
|
||||
Iterations: 1,
|
||||
}, p)
|
||||
}
|
||||
|
||||
func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
|
||||
passphrase := []byte("DO NOT USE THIS KEY")
|
||||
privKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
|
||||
oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
|
||||
+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
|
||||
qrlJ69wer3ZUHFXA
|
||||
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A key which, once decrypted, is too short
|
||||
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
|
||||
k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
|
||||
GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
|
||||
rQr3bdH3Oy/WiYU=
|
||||
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
`)
|
||||
invalidBanner := []byte(`# Invalid banner (not encrypted)
|
||||
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
|
||||
bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
|
||||
XgLvodMXZJuaFPssp+WwtA==
|
||||
-----END NEBULA ED25519 PRIVATE KEY-----
|
||||
`)
|
||||
invalidPem := []byte(`# Not a valid PEM format
|
||||
-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
|
||||
oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
|
||||
+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
|
||||
qrlJ69wer3ZUHFXA
|
||||
-END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
`)
|
||||
|
||||
keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
assert.Len(t, k, 64)
|
||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||
|
||||
// Fail due to short key
|
||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||
require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||
|
||||
// Fail due to invalid banner
|
||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||
require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
|
||||
// Fail due to ivalid PEM format, because
|
||||
// it's missing the requisite pre-encapsulation boundary.
|
||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
|
||||
// Fail due to invalid passphrase
|
||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
|
||||
require.EqualError(t, err, "invalid passphrase or corrupt private key")
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, []byte{}, rest)
|
||||
}
|
||||
|
||||
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
|
||||
// Having proved that decryption works correctly above, we can test the
|
||||
// encryption function produces a value which can be decrypted
|
||||
passphrase := []byte("passphrase")
|
||||
bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
||||
kdfParams := NewArgon2Parameters(64*1024, 4, 3)
|
||||
key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the "key" can be decrypted successfully
|
||||
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
|
||||
assert.Len(t, k, 64)
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
assert.Equal(t, []byte{}, rest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
|
||||
}
|
||||
|
||||
@ -2,13 +2,48 @@ package cert
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRootExpired = errors.New("root certificate is expired")
|
||||
ErrExpired = errors.New("certificate is expired")
|
||||
ErrNotCA = errors.New("certificate is not a CA")
|
||||
ErrNotSelfSigned = errors.New("certificate is not self-signed")
|
||||
ErrBlockListed = errors.New("certificate is in the block list")
|
||||
ErrSignatureMismatch = errors.New("certificate signature did not match")
|
||||
ErrBadFormat = errors.New("bad wire format")
|
||||
ErrRootExpired = errors.New("root certificate is expired")
|
||||
ErrExpired = errors.New("certificate is expired")
|
||||
ErrNotCA = errors.New("certificate is not a CA")
|
||||
ErrNotSelfSigned = errors.New("certificate is not self-signed")
|
||||
ErrBlockListed = errors.New("certificate is in the block list")
|
||||
ErrFingerprintMismatch = errors.New("certificate fingerprint did not match")
|
||||
ErrSignatureMismatch = errors.New("certificate signature did not match")
|
||||
ErrInvalidPublicKey = errors.New("invalid public key")
|
||||
ErrInvalidPrivateKey = errors.New("invalid private key")
|
||||
ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
|
||||
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
|
||||
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
|
||||
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
||||
|
||||
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
||||
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
||||
ErrInvalidPEMX25519PublicKeyBanner = errors.New("bytes did not contain a proper X25519 public key banner")
|
||||
ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner")
|
||||
ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner")
|
||||
ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
|
||||
|
||||
ErrNoPeerStaticKey = errors.New("no peer static key was present")
|
||||
ErrNoPayload = errors.New("provided payload was empty")
|
||||
|
||||
ErrMissingDetails = errors.New("certificate did not contain details")
|
||||
ErrEmptySignature = errors.New("empty signature")
|
||||
ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
|
||||
)
|
||||
|
||||
type ErrInvalidCertificateProperties struct {
|
||||
str string
|
||||
}
|
||||
|
||||
func NewErrInvalidCertificateProperties(format string, a ...any) error {
|
||||
return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
|
||||
}
|
||||
|
||||
func (e *ErrInvalidCertificateProperties) Error() string {
|
||||
return e.str
|
||||
}
|
||||
|
||||
141
cert/helper_test.go
Normal file
141
cert/helper_test.go
Normal file
@ -0,0 +1,141 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
// NewTestCaCert will create a new ca certificate
|
||||
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
|
||||
var err error
|
||||
var pub, priv []byte
|
||||
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
pub, priv, err = ed25519.GenerateKey(rand.Reader)
|
||||
case Curve_P256:
|
||||
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
|
||||
priv = privk.D.FillBytes(make([]byte, 32))
|
||||
default:
|
||||
// There is no default to allow the underlying lib to respond with an error
|
||||
}
|
||||
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
}
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
}
|
||||
|
||||
t := &TBSCertificate{
|
||||
Curve: curve,
|
||||
Version: version,
|
||||
Name: "test ca",
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
Networks: networks,
|
||||
UnsafeNetworks: unsafeNetworks,
|
||||
Groups: groups,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
c, err := t.Sign(nil, curve, priv)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pub, priv, pem
|
||||
}
|
||||
|
||||
// NewTestCert will generate a signed certificate with the provided details.
|
||||
// Expiry times are defaulted if you do not pass them in
|
||||
func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
}
|
||||
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
}
|
||||
|
||||
if len(networks) == 0 {
|
||||
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
|
||||
}
|
||||
|
||||
var pub, priv []byte
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
pub, priv = X25519Keypair()
|
||||
case Curve_P256:
|
||||
pub, priv = P256Keypair()
|
||||
default:
|
||||
panic("unknown curve")
|
||||
}
|
||||
|
||||
nc := &TBSCertificate{
|
||||
Version: v,
|
||||
Curve: curve,
|
||||
Name: name,
|
||||
Networks: networks,
|
||||
UnsafeNetworks: unsafeNetworks,
|
||||
Groups: groups,
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem
|
||||
}
|
||||
|
||||
func X25519Keypair() ([]byte, []byte) {
|
||||
privkey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return pubkey, privkey
|
||||
}
|
||||
|
||||
func P256Keypair() ([]byte, []byte) {
|
||||
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
pubkey := privkey.PublicKey()
|
||||
return pubkey.Bytes(), privkey.Bytes()
|
||||
}
|
||||
161
cert/pem.go
Normal file
161
cert/pem.go
Normal file
@ -0,0 +1,161 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
const (
|
||||
CertificateBanner = "NEBULA CERTIFICATE"
|
||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||
|
||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
||||
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
||||
)
|
||||
|
||||
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
||||
// data or an error on failure
|
||||
func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
||||
p, r := pem.Decode(b)
|
||||
if p == nil {
|
||||
return nil, r, ErrInvalidPEMBlock
|
||||
}
|
||||
|
||||
var c Certificate
|
||||
var err error
|
||||
|
||||
switch p.Type {
|
||||
// Implementations must validate the resulting certificate contains valid information
|
||||
case CertificateBanner:
|
||||
c, err = unmarshalCertificateV1(p.Bytes, nil)
|
||||
case CertificateV2Banner:
|
||||
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
|
||||
default:
|
||||
return nil, r, ErrInvalidPEMCertificateBanner
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, r, err
|
||||
}
|
||||
|
||||
return c, r, nil
|
||||
|
||||
}
|
||||
|
||||
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
|
||||
case Curve_P256:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
k, r := pem.Decode(b)
|
||||
if k == nil {
|
||||
return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
|
||||
}
|
||||
var expectedLen int
|
||||
var curve Curve
|
||||
switch k.Type {
|
||||
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
||||
expectedLen = 32
|
||||
curve = Curve_CURVE25519
|
||||
case P256PublicKeyBanner:
|
||||
// Uncompressed
|
||||
expectedLen = 65
|
||||
curve = Curve_P256
|
||||
default:
|
||||
return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner")
|
||||
}
|
||||
if len(k.Bytes) != expectedLen {
|
||||
return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
|
||||
}
|
||||
return k.Bytes, r, curve, nil
|
||||
}
|
||||
|
||||
func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
|
||||
case Curve_P256:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
|
||||
case Curve_P256:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
||||
// consumed data or an error on failure
|
||||
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
k, r := pem.Decode(b)
|
||||
if k == nil {
|
||||
return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
|
||||
}
|
||||
var expectedLen int
|
||||
var curve Curve
|
||||
switch k.Type {
|
||||
case X25519PrivateKeyBanner:
|
||||
expectedLen = 32
|
||||
curve = Curve_CURVE25519
|
||||
case P256PrivateKeyBanner:
|
||||
expectedLen = 32
|
||||
curve = Curve_P256
|
||||
default:
|
||||
return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner")
|
||||
}
|
||||
if len(k.Bytes) != expectedLen {
|
||||
return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
|
||||
}
|
||||
return k.Bytes, r, curve, nil
|
||||
}
|
||||
|
||||
func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
k, r := pem.Decode(b)
|
||||
if k == nil {
|
||||
return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
|
||||
}
|
||||
var curve Curve
|
||||
switch k.Type {
|
||||
case EncryptedEd25519PrivateKeyBanner:
|
||||
return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
|
||||
case EncryptedECDSAP256PrivateKeyBanner:
|
||||
return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
|
||||
case Ed25519PrivateKeyBanner:
|
||||
curve = Curve_CURVE25519
|
||||
if len(k.Bytes) != ed25519.PrivateKeySize {
|
||||
return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
|
||||
}
|
||||
case ECDSAP256PrivateKeyBanner:
|
||||
curve = Curve_P256
|
||||
if len(k.Bytes) != 32 {
|
||||
return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
|
||||
}
|
||||
default:
|
||||
return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner")
|
||||
}
|
||||
return k.Bytes, r, curve, nil
|
||||
}
|
||||
293
cert/pem_test.go
Normal file
293
cert/pem_test.go
Normal file
@ -0,0 +1,293 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
||||
goodCert := []byte(`
|
||||
# A good cert
|
||||
-----BEGIN NEBULA CERTIFICATE-----
|
||||
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
|
||||
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
|
||||
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
|
||||
-----END NEBULA CERTIFICATE-----
|
||||
`)
|
||||
badBanner := []byte(`# A bad banner
|
||||
-----BEGIN NOT A NEBULA CERTIFICATE-----
|
||||
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
|
||||
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
|
||||
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
|
||||
-----END NOT A NEBULA CERTIFICATE-----
|
||||
`)
|
||||
invalidPem := []byte(`# Not a valid PEM format
|
||||
-BEGIN NEBULA CERTIFICATE-----
|
||||
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
|
||||
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
|
||||
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
|
||||
-END NEBULA CERTIFICATE----`)
|
||||
|
||||
certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
|
||||
assert.NotNil(t, cert)
|
||||
assert.Equal(t, rest, append(badBanner, invalidPem...))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fail due to invalid banner.
|
||||
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
||||
assert.Nil(t, cert)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
require.EqualError(t, err, "bytes did not contain a proper certificate banner")
|
||||
|
||||
// Fail due to ivalid PEM format, because
|
||||
// it's missing the requisite pre-encapsulation boundary.
|
||||
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
||||
assert.Nil(t, cert)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||
}
|
||||
|
||||
func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
|
||||
privKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
|
||||
-----END NEBULA ED25519 PRIVATE KEY-----
|
||||
`)
|
||||
privP256Key := []byte(`# A good key
|
||||
-----BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA ECDSA P256 PRIVATE KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A short key
|
||||
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
-----END NEBULA ED25519 PRIVATE KEY-----
|
||||
`)
|
||||
invalidBanner := []byte(`# Invalid banner
|
||||
-----BEGIN NOT A NEBULA PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
|
||||
-----END NOT A NEBULA PRIVATE KEY-----
|
||||
`)
|
||||
invalidPem := []byte(`# Not a valid PEM format
|
||||
-BEGIN NEBULA ED25519 PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
|
||||
-END NEBULA ED25519 PRIVATE KEY-----`)
|
||||
|
||||
keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle)
|
||||
assert.Len(t, k, 64)
|
||||
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||
assert.Len(t, k, 32)
|
||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_P256, curve)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fail due to short key
|
||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||
require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
|
||||
|
||||
// Fail due to invalid banner
|
||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
|
||||
|
||||
// Fail due to ivalid PEM format, because
|
||||
// it's missing the requisite pre-encapsulation boundary.
|
||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||
}
|
||||
|
||||
func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
|
||||
privKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA X25519 PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA X25519 PRIVATE KEY-----
|
||||
`)
|
||||
privP256Key := []byte(`# A good key
|
||||
-----BEGIN NEBULA P256 PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PRIVATE KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A short key
|
||||
-----BEGIN NEBULA X25519 PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
|
||||
-----END NEBULA X25519 PRIVATE KEY-----
|
||||
`)
|
||||
invalidBanner := []byte(`# Invalid banner
|
||||
-----BEGIN NOT A NEBULA PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NOT A NEBULA PRIVATE KEY-----
|
||||
`)
|
||||
invalidPem := []byte(`# Not a valid PEM format
|
||||
-BEGIN NEBULA X25519 PRIVATE KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-END NEBULA X25519 PRIVATE KEY-----`)
|
||||
|
||||
keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle)
|
||||
assert.Len(t, k, 32)
|
||||
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||
assert.Len(t, k, 32)
|
||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_P256, curve)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fail due to short key
|
||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
|
||||
|
||||
// Fail due to invalid banner
|
||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
require.EqualError(t, err, "bytes did not contain a proper private key banner")
|
||||
|
||||
// Fail due to ivalid PEM format, because
|
||||
// it's missing the requisite pre-encapsulation boundary.
|
||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||
}
|
||||
|
||||
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
||||
pubKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA ED25519 PUBLIC KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A short key
|
||||
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
|
||||
-----END NEBULA ED25519 PUBLIC KEY-----
|
||||
`)
|
||||
invalidBanner := []byte(`# Invalid banner
|
||||
-----BEGIN NOT A NEBULA PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NOT A NEBULA PUBLIC KEY-----
|
||||
`)
|
||||
invalidPem := []byte(`# Not a valid PEM format
|
||||
-BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-END NEBULA ED25519 PUBLIC KEY-----`)
|
||||
|
||||
keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||
assert.Len(t, k, 32)
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||
|
||||
// Fail due to short key
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
||||
|
||||
// Fail due to invalid banner
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
require.EqualError(t, err, "bytes did not contain a proper public key banner")
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
|
||||
// Fail due to ivalid PEM format, because
|
||||
// it's missing the requisite pre-encapsulation boundary.
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||
}
|
||||
|
||||
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
||||
pubKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA X25519 PUBLIC KEY-----
|
||||
`)
|
||||
pubP256Key := []byte(`# A good key
|
||||
-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A short key
|
||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
|
||||
-----END NEBULA X25519 PUBLIC KEY-----
|
||||
`)
|
||||
invalidBanner := []byte(`# Invalid banner
|
||||
-----BEGIN NOT A NEBULA PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NOT A NEBULA PUBLIC KEY-----
|
||||
`)
|
||||
invalidPem := []byte(`# Not a valid PEM format
|
||||
-BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-END NEBULA X25519 PUBLIC KEY-----`)
|
||||
|
||||
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||
assert.Len(t, k, 32)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Len(t, k, 65)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_P256, curve)
|
||||
|
||||
// Fail due to short key
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
||||
|
||||
// Fail due to invalid banner
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
require.EqualError(t, err, "bytes did not contain a proper public key banner")
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
|
||||
// Fail due to ivalid PEM format, because
|
||||
// it's missing the requisite pre-encapsulation boundary.
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Nil(t, k)
|
||||
assert.Equal(t, rest, invalidPem)
|
||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||
}
|
||||
167
cert/sign.go
Normal file
167
cert/sign.go
Normal file
@ -0,0 +1,167 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TBSCertificate represents a certificate intended to be signed.
|
||||
// It is invalid to use this structure as a Certificate.
|
||||
type TBSCertificate struct {
|
||||
Version Version
|
||||
Name string
|
||||
Networks []netip.Prefix
|
||||
UnsafeNetworks []netip.Prefix
|
||||
Groups []string
|
||||
IsCA bool
|
||||
NotBefore time.Time
|
||||
NotAfter time.Time
|
||||
PublicKey []byte
|
||||
Curve Curve
|
||||
issuer string
|
||||
}
|
||||
|
||||
type beingSignedCertificate interface {
|
||||
// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
|
||||
// Implementations must validate the resulting certificate contains valid information
|
||||
fromTBSCertificate(*TBSCertificate) error
|
||||
|
||||
// marshalForSigning returns the bytes that should be signed
|
||||
marshalForSigning() ([]byte, error)
|
||||
|
||||
// setSignature sets the signature for the certificate that has just been signed. The signature must not be blank.
|
||||
setSignature([]byte) error
|
||||
}
|
||||
|
||||
type SignerLambda func(certBytes []byte) ([]byte, error)
|
||||
|
||||
// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
|
||||
// details do not violate constraints of the signing certificate.
|
||||
// If the TBSCertificate is a CA then signer must be nil.
|
||||
func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
|
||||
switch t.Curve {
|
||||
case Curve_CURVE25519:
|
||||
pk := ed25519.PrivateKey(key)
|
||||
sp := func(certBytes []byte) ([]byte, error) {
|
||||
sig := ed25519.Sign(pk, certBytes)
|
||||
return sig, nil
|
||||
}
|
||||
return t.SignWith(signer, curve, sp)
|
||||
case Curve_P256:
|
||||
pk := &ecdsa.PrivateKey{
|
||||
PublicKey: ecdsa.PublicKey{
|
||||
Curve: elliptic.P256(),
|
||||
},
|
||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
|
||||
D: new(big.Int).SetBytes(key),
|
||||
}
|
||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
|
||||
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
|
||||
sp := func(certBytes []byte) ([]byte, error) {
|
||||
// We need to hash first for ECDSA
|
||||
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
||||
hashed := sha256.Sum256(certBytes)
|
||||
return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
|
||||
}
|
||||
return t.SignWith(signer, curve, sp)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid curve: %s", t.Curve)
|
||||
}
|
||||
}
|
||||
|
||||
// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
|
||||
// You should only use SignWith if you do not have direct access to your private key.
|
||||
func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
|
||||
if curve != t.Curve {
|
||||
return nil, fmt.Errorf("curve in cert and private key supplied don't match")
|
||||
}
|
||||
|
||||
if signer != nil {
|
||||
if t.IsCA {
|
||||
return nil, fmt.Errorf("can not sign a CA certificate with another")
|
||||
}
|
||||
|
||||
err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
issuer, err := signer.Fingerprint()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error computing issuer: %v", err)
|
||||
}
|
||||
t.issuer = issuer
|
||||
} else {
|
||||
if !t.IsCA {
|
||||
return nil, fmt.Errorf("self signed certificates must have IsCA set to true")
|
||||
}
|
||||
}
|
||||
|
||||
var c beingSignedCertificate
|
||||
switch t.Version {
|
||||
case Version1:
|
||||
c = &certificateV1{}
|
||||
err := c.fromTBSCertificate(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case Version2:
|
||||
c = &certificateV2{}
|
||||
err := c.fromTBSCertificate(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown cert version %d", t.Version)
|
||||
}
|
||||
|
||||
certBytes, err := c.marshalForSigning()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sig, err := sp(certBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.setSignature(sig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sc, ok := c.(Certificate)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid certificate")
|
||||
}
|
||||
|
||||
return sc, nil
|
||||
}
|
||||
|
||||
func comparePrefix(a, b netip.Prefix) int {
|
||||
addr := a.Addr().Compare(b.Addr())
|
||||
if addr == 0 {
|
||||
return a.Bits() - b.Bits()
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
|
||||
func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
|
||||
if len(sortedPrefixes) < 2 {
|
||||
return nil
|
||||
}
|
||||
for i := 1; i < len(sortedPrefixes); i++ {
|
||||
if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
|
||||
return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
91
cert/sign_test.go
Normal file
91
cert/sign_test.go
Normal file
@ -0,0 +1,91 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCertificateV1_Sign(t *testing.T) {
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
|
||||
tbs := TBSCertificate{
|
||||
Version: Version1,
|
||||
Name: "testing",
|
||||
Networks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||
},
|
||||
UnsafeNetworks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||
mustParsePrefixUnmapped("9.1.1.3/24"),
|
||||
},
|
||||
Groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
NotBefore: before,
|
||||
NotAfter: after,
|
||||
PublicKey: pubKey,
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, c)
|
||||
assert.True(t, c.CheckSignature(pub))
|
||||
|
||||
b, err := c.Marshal()
|
||||
require.NoError(t, err)
|
||||
uc, err := unmarshalCertificateV1(b, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, uc)
|
||||
}
|
||||
|
||||
func TestCertificateV1_SignP256(t *testing.T) {
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
|
||||
|
||||
tbs := TBSCertificate{
|
||||
Version: Version1,
|
||||
Name: "testing",
|
||||
Networks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||
},
|
||||
UnsafeNetworks: []netip.Prefix{
|
||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||
},
|
||||
Groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
NotBefore: before,
|
||||
NotAfter: after,
|
||||
PublicKey: pubKey,
|
||||
IsCA: false,
|
||||
Curve: Curve_P256,
|
||||
}
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
|
||||
rawPriv := priv.D.FillBytes(make([]byte, 32))
|
||||
|
||||
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, c)
|
||||
assert.True(t, c.CheckSignature(pub))
|
||||
|
||||
b, err := c.Marshal()
|
||||
require.NoError(t, err)
|
||||
uc, err := unmarshalCertificateV1(b, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, uc)
|
||||
}
|
||||
138
cert_test/cert.go
Normal file
138
cert_test/cert.go
Normal file
@ -0,0 +1,138 @@
|
||||
package cert_test
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
// NewTestCaCert will create a new ca certificate
|
||||
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
|
||||
var err error
|
||||
var pub, priv []byte
|
||||
|
||||
switch curve {
|
||||
case cert.Curve_CURVE25519:
|
||||
pub, priv, err = ed25519.GenerateKey(rand.Reader)
|
||||
case cert.Curve_P256:
|
||||
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
|
||||
priv = privk.D.FillBytes(make([]byte, 32))
|
||||
default:
|
||||
// There is no default to allow the underlying lib to respond with an error
|
||||
}
|
||||
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
}
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
}
|
||||
|
||||
t := &cert.TBSCertificate{
|
||||
Curve: curve,
|
||||
Version: version,
|
||||
Name: "test ca",
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
Networks: networks,
|
||||
UnsafeNetworks: unsafeNetworks,
|
||||
Groups: groups,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
c, err := t.Sign(nil, curve, priv)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pub, priv, pem
|
||||
}
|
||||
|
||||
// NewTestCert will generate a signed certificate with the provided details.
|
||||
// Expiry times are defaulted if you do not pass them in
|
||||
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
}
|
||||
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
}
|
||||
|
||||
var pub, priv []byte
|
||||
switch curve {
|
||||
case cert.Curve_CURVE25519:
|
||||
pub, priv = X25519Keypair()
|
||||
case cert.Curve_P256:
|
||||
pub, priv = P256Keypair()
|
||||
default:
|
||||
panic("unknown curve")
|
||||
}
|
||||
|
||||
nc := &cert.TBSCertificate{
|
||||
Version: v,
|
||||
Curve: curve,
|
||||
Name: name,
|
||||
Networks: networks,
|
||||
UnsafeNetworks: unsafeNetworks,
|
||||
Groups: groups,
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||
}
|
||||
|
||||
func X25519Keypair() ([]byte, []byte) {
|
||||
privkey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return pubkey, privkey
|
||||
}
|
||||
|
||||
func P256Keypair() ([]byte, []byte) {
|
||||
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
pubkey := privkey.PublicKey()
|
||||
return pubkey.Bytes(), privkey.Bytes()
|
||||
}
|
||||
@ -4,12 +4,11 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@ -28,34 +27,43 @@ type caFlags struct {
|
||||
outCertPath *string
|
||||
outQRPath *string
|
||||
groups *string
|
||||
ips *string
|
||||
subnets *string
|
||||
networks *string
|
||||
unsafeNetworks *string
|
||||
argonMemory *uint
|
||||
argonIterations *uint
|
||||
argonParallelism *uint
|
||||
encryption *bool
|
||||
version *uint
|
||||
|
||||
curve *string
|
||||
p11url *string
|
||||
|
||||
// Deprecated options
|
||||
ips *string
|
||||
subnets *string
|
||||
}
|
||||
|
||||
func newCaFlags() *caFlags {
|
||||
cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
|
||||
cf.set.Usage = func() {}
|
||||
cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
|
||||
cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
|
||||
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
|
||||
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
|
||||
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
|
||||
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
|
||||
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
|
||||
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
|
||||
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
|
||||
cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
|
||||
cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
|
||||
cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
|
||||
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
|
||||
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
|
||||
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
|
||||
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
|
||||
cf.p11url = p11Flag(cf.set)
|
||||
|
||||
cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
|
||||
cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
|
||||
return &cf
|
||||
}
|
||||
|
||||
@ -114,38 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
}
|
||||
}
|
||||
|
||||
var ips []*net.IPNet
|
||||
if *cf.ips != "" {
|
||||
for _, rs := range strings.Split(*cf.ips, ",") {
|
||||
version := cert.Version(*cf.version)
|
||||
if version != cert.Version1 && version != cert.Version2 {
|
||||
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
|
||||
}
|
||||
|
||||
var networks []netip.Prefix
|
||||
if *cf.networks == "" && *cf.ips != "" {
|
||||
// Pull up deprecated -ips flag if needed
|
||||
*cf.networks = *cf.ips
|
||||
}
|
||||
|
||||
if *cf.networks != "" {
|
||||
for _, rs := range strings.Split(*cf.networks, ",") {
|
||||
rs := strings.Trim(rs, " ")
|
||||
if rs != "" {
|
||||
ip, ipNet, err := net.ParseCIDR(rs)
|
||||
n, err := netip.ParsePrefix(rs)
|
||||
if err != nil {
|
||||
return newHelpErrorf("invalid ip definition: %s", err)
|
||||
return newHelpErrorf("invalid -networks definition: %s", rs)
|
||||
}
|
||||
if ip.To4() == nil {
|
||||
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
|
||||
if version == cert.Version1 && !n.Addr().Is4() {
|
||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
|
||||
}
|
||||
|
||||
ipNet.IP = ip
|
||||
ips = append(ips, ipNet)
|
||||
networks = append(networks, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var subnets []*net.IPNet
|
||||
if *cf.subnets != "" {
|
||||
for _, rs := range strings.Split(*cf.subnets, ",") {
|
||||
var unsafeNetworks []netip.Prefix
|
||||
if *cf.unsafeNetworks == "" && *cf.subnets != "" {
|
||||
// Pull up deprecated -subnets flag if needed
|
||||
*cf.unsafeNetworks = *cf.subnets
|
||||
}
|
||||
|
||||
if *cf.unsafeNetworks != "" {
|
||||
for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
|
||||
rs := strings.Trim(rs, " ")
|
||||
if rs != "" {
|
||||
_, s, err := net.ParseCIDR(rs)
|
||||
n, err := netip.ParsePrefix(rs)
|
||||
if err != nil {
|
||||
return newHelpErrorf("invalid subnet definition: %s", err)
|
||||
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
|
||||
}
|
||||
if s.IP.To4() == nil {
|
||||
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
|
||||
if version == cert.Version1 && !n.Addr().Is4() {
|
||||
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
|
||||
}
|
||||
subnets = append(subnets, s)
|
||||
unsafeNetworks = append(unsafeNetworks, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -224,19 +245,17 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
}
|
||||
}
|
||||
|
||||
nc := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: *cf.name,
|
||||
Groups: groups,
|
||||
Ips: ips,
|
||||
Subnets: subnets,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(*cf.duration),
|
||||
PublicKey: pub,
|
||||
IsCA: true,
|
||||
Curve: curve,
|
||||
},
|
||||
Pkcs11Backed: isP11,
|
||||
t := &cert.TBSCertificate{
|
||||
Version: version,
|
||||
Name: *cf.name,
|
||||
Groups: groups,
|
||||
Networks: networks,
|
||||
UnsafeNetworks: unsafeNetworks,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(*cf.duration),
|
||||
PublicKey: pub,
|
||||
IsCA: true,
|
||||
Curve: curve,
|
||||
}
|
||||
|
||||
if !isP11 {
|
||||
@ -249,15 +268,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
|
||||
}
|
||||
|
||||
var c cert.Certificate
|
||||
var b []byte
|
||||
|
||||
if isP11 {
|
||||
err = nc.SignPkcs11(curve, p11Client)
|
||||
c, err = t.SignWith(nil, curve, p11Client.SignASN1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = nc.Sign(curve, rawPriv)
|
||||
c, err = t.Sign(nil, curve, rawPriv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while signing: %s", err)
|
||||
}
|
||||
@ -268,19 +288,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
return fmt.Errorf("error while encrypting out-key: %s", err)
|
||||
}
|
||||
} else {
|
||||
b = cert.MarshalSigningPrivateKey(curve, rawPriv)
|
||||
b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*cf.outKeyPath, b, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
if _, err := os.Stat(*cf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
|
||||
}
|
||||
}
|
||||
|
||||
b, err = nc.MarshalToPEM()
|
||||
b, err = c.MarshalPEM()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while marshalling certificate: %s", err)
|
||||
}
|
||||
|
||||
@ -14,10 +14,9 @@ import (
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//TODO: test file permissions
|
||||
|
||||
func Test_caSummary(t *testing.T) {
|
||||
assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
|
||||
}
|
||||
@ -43,9 +42,11 @@ func Test_caHelp(t *testing.T) {
|
||||
" -groups string\n"+
|
||||
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
|
||||
" -ips string\n"+
|
||||
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
|
||||
" Deprecated, see -networks\n"+
|
||||
" -name string\n"+
|
||||
" \tRequired: name of the certificate authority\n"+
|
||||
" -networks string\n"+
|
||||
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
|
||||
" -out-crt string\n"+
|
||||
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
|
||||
" -out-key string\n"+
|
||||
@ -54,7 +55,11 @@ func Test_caHelp(t *testing.T) {
|
||||
" \tOptional: output a qr code image (png) of the certificate\n"+
|
||||
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
|
||||
" -subnets string\n"+
|
||||
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
|
||||
" \tDeprecated, see -unsafe-networks\n"+
|
||||
" -unsafe-networks string\n"+
|
||||
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
|
||||
" -version uint\n"+
|
||||
" \tOptional: version of the certificate format to use (default 2)\n",
|
||||
ob.String(),
|
||||
)
|
||||
}
|
||||
@ -83,85 +88,86 @@ func Test_ca(t *testing.T) {
|
||||
|
||||
// required args
|
||||
assertHelpError(t, ca(
|
||||
[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
||||
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
||||
), "-name is required")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// ipv4 only ips
|
||||
assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
|
||||
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// ipv4 only subnets
|
||||
assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
|
||||
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// failed key write
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
||||
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// create temp key file
|
||||
keyF, err := os.CreateTemp("", "test.key")
|
||||
assert.Nil(t, err)
|
||||
os.Remove(keyF.Name())
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.Remove(keyF.Name()))
|
||||
|
||||
// failed cert write
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
||||
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// create temp cert file
|
||||
crtF, err := os.CreateTemp("", "test.crt")
|
||||
assert.Nil(t, err)
|
||||
os.Remove(crtF.Name())
|
||||
os.Remove(keyF.Name())
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.Remove(crtF.Name()))
|
||||
require.NoError(t, os.Remove(keyF.Name()))
|
||||
|
||||
// test proper cert with removed empty groups and subnets
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
assert.Nil(t, ca(args, ob, eb, nopw))
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.NoError(t, ca(args, ob, eb, nopw))
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// read cert and key files
|
||||
rb, _ := os.ReadFile(keyF.Name())
|
||||
lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb)
|
||||
assert.Len(t, b, 0)
|
||||
assert.Nil(t, err)
|
||||
lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, c)
|
||||
assert.Empty(t, b)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, lKey, 64)
|
||||
|
||||
rb, _ = os.ReadFile(crtF.Name())
|
||||
lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
|
||||
assert.Len(t, b, 0)
|
||||
assert.Nil(t, err)
|
||||
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
||||
assert.Empty(t, b)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test", lCrt.Details.Name)
|
||||
assert.Len(t, lCrt.Details.Ips, 0)
|
||||
assert.True(t, lCrt.Details.IsCA)
|
||||
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
|
||||
assert.Len(t, lCrt.Details.Subnets, 0)
|
||||
assert.Len(t, lCrt.Details.PublicKey, 32)
|
||||
assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
|
||||
assert.Equal(t, "", lCrt.Details.Issuer)
|
||||
assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey))
|
||||
assert.Equal(t, "test", lCrt.Name())
|
||||
assert.Empty(t, lCrt.Networks())
|
||||
assert.True(t, lCrt.IsCA())
|
||||
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
|
||||
assert.Empty(t, lCrt.UnsafeNetworks())
|
||||
assert.Len(t, lCrt.PublicKey(), 32)
|
||||
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
|
||||
assert.Equal(t, "", lCrt.Issuer())
|
||||
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
|
||||
|
||||
// test encrypted key
|
||||
os.Remove(keyF.Name())
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
assert.Nil(t, ca(args, ob, eb, testpw))
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.NoError(t, ca(args, ob, eb, testpw))
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
@ -169,7 +175,7 @@ func Test_ca(t *testing.T) {
|
||||
rb, _ = os.ReadFile(keyF.Name())
|
||||
k, _ := pem.Decode(rb)
|
||||
ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
// we won't know salt in advance, so just check start of string
|
||||
assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
|
||||
assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
|
||||
@ -179,8 +185,8 @@ func Test_ca(t *testing.T) {
|
||||
var curve cert.Curve
|
||||
curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, b, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
assert.Len(t, lKey, 64)
|
||||
|
||||
// test when reading passsword results in an error
|
||||
@ -188,8 +194,8 @@ func Test_ca(t *testing.T) {
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
assert.Error(t, ca(args, ob, eb, errpw))
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.Error(t, ca(args, ob, eb, errpw))
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
@ -198,8 +204,8 @@ func Test_ca(t *testing.T) {
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
@ -208,14 +214,14 @@ func Test_ca(t *testing.T) {
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
assert.Nil(t, ca(args, ob, eb, nopw))
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.NoError(t, ca(args, ob, eb, nopw))
|
||||
|
||||
// test that we won't overwrite existing certificate file
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
@ -223,8 +229,8 @@ func Test_ca(t *testing.T) {
|
||||
os.Remove(keyF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
os.Remove(keyF.Name())
|
||||
|
||||
@ -82,12 +82,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
|
||||
return fmt.Errorf("error while getting public key: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
|
||||
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
}
|
||||
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
|
||||
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-pub: %s", err)
|
||||
}
|
||||
|
||||
@ -7,10 +7,9 @@ import (
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//TODO: test file permissions
|
||||
|
||||
func Test_keygenSummary(t *testing.T) {
|
||||
assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
|
||||
}
|
||||
@ -49,46 +48,48 @@ func Test_keygen(t *testing.T) {
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
|
||||
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||
require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// create temp key file
|
||||
keyF, err := os.CreateTemp("", "test.key")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(keyF.Name())
|
||||
|
||||
// failed pub write
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
|
||||
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
|
||||
require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// create temp pub file
|
||||
pubF, err := os.CreateTemp("", "test.pub")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(pubF.Name())
|
||||
|
||||
// test proper keygen
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
|
||||
assert.Nil(t, keygen(args, ob, eb))
|
||||
require.NoError(t, keygen(args, ob, eb))
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
|
||||
// read cert and key files
|
||||
rb, _ := os.ReadFile(keyF.Name())
|
||||
lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
|
||||
assert.Len(t, b, 0)
|
||||
assert.Nil(t, err)
|
||||
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
assert.Empty(t, b)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, lKey, 32)
|
||||
|
||||
rb, _ = os.ReadFile(pubF.Name())
|
||||
lPub, b, err := cert.UnmarshalX25519PublicKey(rb)
|
||||
assert.Len(t, b, 0)
|
||||
assert.Nil(t, err)
|
||||
lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
assert.Empty(t, b)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, lPub, 32)
|
||||
}
|
||||
|
||||
@ -9,10 +9,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//TODO: all flag parsing continueOnError will print to stderr on its own currently
|
||||
|
||||
func Test_help(t *testing.T) {
|
||||
expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
|
||||
" Global flags:\n" +
|
||||
@ -81,7 +80,7 @@ func assertHelpError(t *testing.T, err error, msg string) {
|
||||
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
|
||||
}
|
||||
|
||||
assert.EqualError(t, err, msg)
|
||||
require.EqualError(t, err, msg)
|
||||
}
|
||||
|
||||
func optionalPkcs11String(msg string) string {
|
||||
|
||||
@ -45,28 +45,27 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
return fmt.Errorf("unable to read cert; %s", err)
|
||||
}
|
||||
|
||||
var c *cert.NebulaCertificate
|
||||
var c cert.Certificate
|
||||
var qrBytes []byte
|
||||
part := 0
|
||||
|
||||
var jsonCerts []cert.Certificate
|
||||
|
||||
for {
|
||||
c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert)
|
||||
c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while unmarshaling cert: %s", err)
|
||||
}
|
||||
|
||||
if *pf.json {
|
||||
b, _ := json.Marshal(c)
|
||||
out.Write(b)
|
||||
out.Write([]byte("\n"))
|
||||
|
||||
jsonCerts = append(jsonCerts, c)
|
||||
} else {
|
||||
out.Write([]byte(c.String()))
|
||||
out.Write([]byte("\n"))
|
||||
_, _ = out.Write([]byte(c.String()))
|
||||
_, _ = out.Write([]byte("\n"))
|
||||
}
|
||||
|
||||
if *pf.outQRPath != "" {
|
||||
b, err := c.MarshalToPEM()
|
||||
b, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while marshalling cert to PEM: %s", err)
|
||||
}
|
||||
@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
part++
|
||||
}
|
||||
|
||||
if *pf.json {
|
||||
b, _ := json.Marshal(jsonCerts)
|
||||
_, _ = out.Write(b)
|
||||
_, _ = out.Write([]byte("\n"))
|
||||
}
|
||||
|
||||
if *pf.outQRPath != "" {
|
||||
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
|
||||
if err != nil {
|
||||
|
||||
@ -2,12 +2,17 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_printSummary(t *testing.T) {
|
||||
@ -48,45 +53,106 @@ func Test_printCert(t *testing.T) {
|
||||
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
|
||||
require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
|
||||
|
||||
// invalid cert at path
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
tf, err := os.CreateTemp("", "print-cert")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tf.Name())
|
||||
|
||||
tf.WriteString("-----BEGIN NOPE-----")
|
||||
err = printCert([]string{"-path", tf.Name()}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
|
||||
require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
|
||||
|
||||
// test multiple certs
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
tf.Truncate(0)
|
||||
tf.Seek(0, 0)
|
||||
c := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "test",
|
||||
Groups: []string{"hi"},
|
||||
PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
|
||||
},
|
||||
Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
|
||||
}
|
||||
ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})
|
||||
|
||||
p, _ := c.MarshalToPEM()
|
||||
p, _ := c.MarshalPEM()
|
||||
tf.Write(p)
|
||||
tf.Write(p)
|
||||
tf.Write(p)
|
||||
|
||||
err = printCert([]string{"-path", tf.Name()}, ob, eb)
|
||||
assert.Nil(t, err)
|
||||
fp, _ := c.Fingerprint()
|
||||
pk := hex.EncodeToString(c.PublicKey())
|
||||
sig := hex.EncodeToString(c.Signature())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n",
|
||||
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
|
||||
`{
|
||||
"details": {
|
||||
"curve": "CURVE25519",
|
||||
"groups": [
|
||||
"hi"
|
||||
],
|
||||
"isCa": false,
|
||||
"issuer": "`+c.Issuer()+`",
|
||||
"name": "test",
|
||||
"networks": [
|
||||
"10.0.0.123/8"
|
||||
],
|
||||
"notAfter": "0001-01-01T00:00:00Z",
|
||||
"notBefore": "0001-01-01T00:00:00Z",
|
||||
"publicKey": "`+pk+`",
|
||||
"unsafeNetworks": []
|
||||
},
|
||||
"fingerprint": "`+fp+`",
|
||||
"signature": "`+sig+`",
|
||||
"version": 1
|
||||
}
|
||||
{
|
||||
"details": {
|
||||
"curve": "CURVE25519",
|
||||
"groups": [
|
||||
"hi"
|
||||
],
|
||||
"isCa": false,
|
||||
"issuer": "`+c.Issuer()+`",
|
||||
"name": "test",
|
||||
"networks": [
|
||||
"10.0.0.123/8"
|
||||
],
|
||||
"notAfter": "0001-01-01T00:00:00Z",
|
||||
"notBefore": "0001-01-01T00:00:00Z",
|
||||
"publicKey": "`+pk+`",
|
||||
"unsafeNetworks": []
|
||||
},
|
||||
"fingerprint": "`+fp+`",
|
||||
"signature": "`+sig+`",
|
||||
"version": 1
|
||||
}
|
||||
{
|
||||
"details": {
|
||||
"curve": "CURVE25519",
|
||||
"groups": [
|
||||
"hi"
|
||||
],
|
||||
"isCa": false,
|
||||
"issuer": "`+c.Issuer()+`",
|
||||
"name": "test",
|
||||
"networks": [
|
||||
"10.0.0.123/8"
|
||||
],
|
||||
"notAfter": "0001-01-01T00:00:00Z",
|
||||
"notBefore": "0001-01-01T00:00:00Z",
|
||||
"publicKey": "`+pk+`",
|
||||
"unsafeNetworks": []
|
||||
},
|
||||
"fingerprint": "`+fp+`",
|
||||
"signature": "`+sig+`",
|
||||
"version": 1
|
||||
}
|
||||
`,
|
||||
ob.String(),
|
||||
)
|
||||
assert.Equal(t, "", eb.String())
|
||||
@ -96,26 +162,84 @@ func Test_printCert(t *testing.T) {
|
||||
eb.Reset()
|
||||
tf.Truncate(0)
|
||||
tf.Seek(0, 0)
|
||||
c = cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "test",
|
||||
Groups: []string{"hi"},
|
||||
PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
|
||||
},
|
||||
Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
|
||||
}
|
||||
|
||||
p, _ = c.MarshalToPEM()
|
||||
tf.Write(p)
|
||||
tf.Write(p)
|
||||
tf.Write(p)
|
||||
|
||||
err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb)
|
||||
assert.Nil(t, err)
|
||||
fp, _ = c.Fingerprint()
|
||||
pk = hex.EncodeToString(c.PublicKey())
|
||||
sig = hex.EncodeToString(c.Signature())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n",
|
||||
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
|
||||
`,
|
||||
ob.String(),
|
||||
)
|
||||
assert.Equal(t, "", eb.String())
|
||||
}
|
||||
|
||||
// NewTestCaCert will generate a CA cert
|
||||
func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
|
||||
var err error
|
||||
if pubKey == nil || privKey == nil {
|
||||
pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
t := &cert.TBSCertificate{
|
||||
Version: cert.Version1,
|
||||
Name: name,
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pubKey,
|
||||
Networks: networks,
|
||||
UnsafeNetworks: unsafeNetworks,
|
||||
Groups: groups,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, privKey
|
||||
}
|
||||
|
||||
func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
|
||||
if before.IsZero() {
|
||||
before = ca.NotBefore()
|
||||
}
|
||||
|
||||
if after.IsZero() {
|
||||
after = ca.NotAfter()
|
||||
}
|
||||
|
||||
if len(networks) == 0 {
|
||||
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
|
||||
}
|
||||
|
||||
pub, rawPriv := x25519Keypair()
|
||||
nc := &cert.TBSCertificate{
|
||||
Version: cert.Version1,
|
||||
Name: name,
|
||||
Networks: networks,
|
||||
UnsafeNetworks: unsafeNetworks,
|
||||
Groups: groups,
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
c, err := nc.Sign(ca, ca.Curve(), signerKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, rawPriv
|
||||
}
|
||||
|
||||
@ -3,10 +3,11 @@ package main
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@ -18,36 +19,46 @@ import (
|
||||
)
|
||||
|
||||
type signFlags struct {
|
||||
set *flag.FlagSet
|
||||
caKeyPath *string
|
||||
caCertPath *string
|
||||
name *string
|
||||
ip *string
|
||||
duration *time.Duration
|
||||
inPubPath *string
|
||||
outKeyPath *string
|
||||
outCertPath *string
|
||||
outQRPath *string
|
||||
groups *string
|
||||
subnets *string
|
||||
p11url *string
|
||||
set *flag.FlagSet
|
||||
version *uint
|
||||
caKeyPath *string
|
||||
caCertPath *string
|
||||
name *string
|
||||
networks *string
|
||||
unsafeNetworks *string
|
||||
duration *time.Duration
|
||||
inPubPath *string
|
||||
outKeyPath *string
|
||||
outCertPath *string
|
||||
outQRPath *string
|
||||
groups *string
|
||||
|
||||
p11url *string
|
||||
|
||||
// Deprecated options
|
||||
ip *string
|
||||
subnets *string
|
||||
}
|
||||
|
||||
func newSignFlags() *signFlags {
|
||||
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
||||
sf.set.Usage = func() {}
|
||||
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
|
||||
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
||||
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
||||
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
||||
sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
|
||||
sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
|
||||
sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
|
||||
sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
|
||||
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
|
||||
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
|
||||
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
|
||||
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
|
||||
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
|
||||
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
|
||||
sf.p11url = p11Flag(sf.set)
|
||||
|
||||
sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
|
||||
sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
|
||||
return &sf
|
||||
}
|
||||
|
||||
@ -71,32 +82,47 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
if err := mustFlagString("name", sf.name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mustFlagString("ip", sf.ip); err != nil {
|
||||
return err
|
||||
}
|
||||
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
|
||||
return newHelpErrorf("cannot set both -in-pub and -out-key")
|
||||
}
|
||||
|
||||
var v4Networks []netip.Prefix
|
||||
var v6Networks []netip.Prefix
|
||||
if *sf.networks == "" && *sf.ip != "" {
|
||||
// Pull up deprecated -ip flag if needed
|
||||
*sf.networks = *sf.ip
|
||||
}
|
||||
|
||||
if len(*sf.networks) == 0 {
|
||||
return newHelpErrorf("-networks is required")
|
||||
}
|
||||
|
||||
version := cert.Version(*sf.version)
|
||||
if version != 0 && version != cert.Version1 && version != cert.Version2 {
|
||||
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
|
||||
}
|
||||
|
||||
var curve cert.Curve
|
||||
var caKey []byte
|
||||
|
||||
if !isP11 {
|
||||
var rawCAKey []byte
|
||||
rawCAKey, err := os.ReadFile(*sf.caKeyPath)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca-key: %s", err)
|
||||
}
|
||||
|
||||
// naively attempt to decode the private key as though it is not encrypted
|
||||
caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
|
||||
if err == cert.ErrPrivateKeyEncrypted {
|
||||
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
||||
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
|
||||
// ask for a passphrase until we get one
|
||||
var passphrase []byte
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if err == ErrNoTerminal {
|
||||
if errors.Is(err, ErrNoTerminal) {
|
||||
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error reading password: %s", err)
|
||||
@ -124,7 +150,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
return fmt.Errorf("error while reading ca-crt: %s", err)
|
||||
}
|
||||
|
||||
caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert)
|
||||
caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while parsing ca-crt: %s", err)
|
||||
}
|
||||
@ -135,30 +161,59 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
}
|
||||
}
|
||||
|
||||
issuer, err := caCert.Sha256Sum()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err)
|
||||
}
|
||||
|
||||
if caCert.Expired(time.Now()) {
|
||||
return fmt.Errorf("ca certificate is expired")
|
||||
}
|
||||
|
||||
// if no duration is given, expire one second before the root expires
|
||||
if *sf.duration <= 0 {
|
||||
*sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1
|
||||
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
||||
}
|
||||
|
||||
ip, ipNet, err := net.ParseCIDR(*sf.ip)
|
||||
if err != nil {
|
||||
return newHelpErrorf("invalid ip definition: %s", err)
|
||||
}
|
||||
if ip.To4() == nil {
|
||||
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
|
||||
}
|
||||
ipNet.IP = ip
|
||||
if *sf.networks != "" {
|
||||
for _, rs := range strings.Split(*sf.networks, ",") {
|
||||
rs := strings.Trim(rs, " ")
|
||||
if rs != "" {
|
||||
n, err := netip.ParsePrefix(rs)
|
||||
if err != nil {
|
||||
return newHelpErrorf("invalid -networks definition: %s", rs)
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
if n.Addr().Is4() {
|
||||
v4Networks = append(v4Networks, n)
|
||||
} else {
|
||||
v6Networks = append(v6Networks, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var v4UnsafeNetworks []netip.Prefix
|
||||
var v6UnsafeNetworks []netip.Prefix
|
||||
if *sf.unsafeNetworks == "" && *sf.subnets != "" {
|
||||
// Pull up deprecated -subnets flag if needed
|
||||
*sf.unsafeNetworks = *sf.subnets
|
||||
}
|
||||
|
||||
if *sf.unsafeNetworks != "" {
|
||||
for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
|
||||
rs := strings.Trim(rs, " ")
|
||||
if rs != "" {
|
||||
n, err := netip.ParsePrefix(rs)
|
||||
if err != nil {
|
||||
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
|
||||
}
|
||||
|
||||
if n.Addr().Is4() {
|
||||
v4UnsafeNetworks = append(v4UnsafeNetworks, n)
|
||||
} else {
|
||||
v6UnsafeNetworks = append(v6UnsafeNetworks, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var groups []string
|
||||
if *sf.groups != "" {
|
||||
for _, rg := range strings.Split(*sf.groups, ",") {
|
||||
g := strings.TrimSpace(rg)
|
||||
@ -168,23 +223,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
}
|
||||
}
|
||||
|
||||
subnets := []*net.IPNet{}
|
||||
if *sf.subnets != "" {
|
||||
for _, rs := range strings.Split(*sf.subnets, ",") {
|
||||
rs := strings.Trim(rs, " ")
|
||||
if rs != "" {
|
||||
_, s, err := net.ParseCIDR(rs)
|
||||
if err != nil {
|
||||
return newHelpErrorf("invalid subnet definition: %s", err)
|
||||
}
|
||||
if s.IP.To4() == nil {
|
||||
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
|
||||
}
|
||||
subnets = append(subnets, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var pub, rawPriv []byte
|
||||
var p11Client *pkclient.PKClient
|
||||
|
||||
@ -205,7 +243,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading in-pub: %s", err)
|
||||
}
|
||||
pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub)
|
||||
|
||||
pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while parsing in-pub: %s", err)
|
||||
}
|
||||
@ -221,38 +260,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
pub, rawPriv = newKeypair(curve)
|
||||
}
|
||||
|
||||
nc := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: *sf.name,
|
||||
Ips: []*net.IPNet{ipNet},
|
||||
Groups: groups,
|
||||
Subnets: subnets,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(*sf.duration),
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
Issuer: issuer,
|
||||
Curve: curve,
|
||||
},
|
||||
Pkcs11Backed: isP11,
|
||||
}
|
||||
|
||||
if p11Client == nil {
|
||||
err = nc.Sign(curve, caKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while signing: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = nc.SignPkcs11(curve, p11Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := nc.CheckRootConstrains(caCert); err != nil {
|
||||
return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err)
|
||||
}
|
||||
|
||||
if *sf.outKeyPath == "" {
|
||||
*sf.outKeyPath = *sf.name + ".key"
|
||||
}
|
||||
@ -265,20 +272,105 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
|
||||
}
|
||||
|
||||
var crts []cert.Certificate
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(*sf.duration)
|
||||
|
||||
if version == 0 || version == cert.Version1 {
|
||||
// Make sure we at least have an ip
|
||||
if len(v4Networks) != 1 {
|
||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||
}
|
||||
|
||||
if version == cert.Version1 {
|
||||
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
|
||||
if len(v6Networks) > 0 {
|
||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
|
||||
}
|
||||
|
||||
if len(v6UnsafeNetworks) > 0 {
|
||||
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
||||
}
|
||||
}
|
||||
|
||||
t := &cert.TBSCertificate{
|
||||
Version: cert.Version1,
|
||||
Name: *sf.name,
|
||||
Networks: []netip.Prefix{v4Networks[0]},
|
||||
Groups: groups,
|
||||
UnsafeNetworks: v4UnsafeNetworks,
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
Curve: curve,
|
||||
}
|
||||
|
||||
var nc cert.Certificate
|
||||
if p11Client == nil {
|
||||
nc, err = t.Sign(caCert, curve, caKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while signing: %w", err)
|
||||
}
|
||||
} else {
|
||||
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
crts = append(crts, nc)
|
||||
}
|
||||
|
||||
if version == 0 || version == cert.Version2 {
|
||||
t := &cert.TBSCertificate{
|
||||
Version: cert.Version2,
|
||||
Name: *sf.name,
|
||||
Networks: append(v4Networks, v6Networks...),
|
||||
Groups: groups,
|
||||
UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
Curve: curve,
|
||||
}
|
||||
|
||||
var nc cert.Certificate
|
||||
if p11Client == nil {
|
||||
nc, err = t.Sign(caCert, curve, caKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while signing: %w", err)
|
||||
}
|
||||
} else {
|
||||
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
crts = append(crts, nc)
|
||||
}
|
||||
|
||||
if !isP11 && *sf.inPubPath == "" {
|
||||
if _, err := os.Stat(*sf.outKeyPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
|
||||
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
b, err := nc.MarshalToPEM()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while marshalling certificate: %s", err)
|
||||
var b []byte
|
||||
for _, c := range crts {
|
||||
sb, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while marshalling certificate: %s", err)
|
||||
}
|
||||
b = append(b, sb...)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*sf.outCertPath, b, 0600)
|
||||
|
||||
@ -13,11 +13,10 @@ import (
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
//TODO: test file permissions
|
||||
|
||||
func Test_signSummary(t *testing.T) {
|
||||
assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
|
||||
}
|
||||
@ -39,9 +38,11 @@ func Test_signHelp(t *testing.T) {
|
||||
" -in-pub string\n"+
|
||||
" \tOptional (if out-key not set): path to read a previously generated public key\n"+
|
||||
" -ip string\n"+
|
||||
" \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
|
||||
" \tDeprecated, see -networks\n"+
|
||||
" -name string\n"+
|
||||
" \tRequired: name of the cert, usually a hostname\n"+
|
||||
" -networks string\n"+
|
||||
" \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
|
||||
" -out-crt string\n"+
|
||||
" \tOptional: path to write the certificate to\n"+
|
||||
" -out-key string\n"+
|
||||
@ -50,7 +51,11 @@ func Test_signHelp(t *testing.T) {
|
||||
" \tOptional: output a qr code image (png) of the certificate\n"+
|
||||
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
|
||||
" -subnets string\n"+
|
||||
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
|
||||
" \tDeprecated, see -unsafe-networks\n"+
|
||||
" -unsafe-networks string\n"+
|
||||
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
|
||||
" -version uint\n"+
|
||||
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
|
||||
ob.String(),
|
||||
)
|
||||
}
|
||||
@ -77,20 +82,20 @@ func Test_signCert(t *testing.T) {
|
||||
|
||||
// required args
|
||||
assertHelpError(t, signCert(
|
||||
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
|
||||
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
|
||||
), "-name is required")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
assertHelpError(t, signCert(
|
||||
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
|
||||
), "-ip is required")
|
||||
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
|
||||
), "-networks is required")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// cannot set -in-pub and -out-key
|
||||
assertHelpError(t, signCert(
|
||||
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
|
||||
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
|
||||
), "cannot set both -in-pub and -out-key")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
@ -98,18 +103,18 @@ func Test_signCert(t *testing.T) {
|
||||
// failed to read key
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
|
||||
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
|
||||
|
||||
// failed to unmarshal key
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
caKeyF, err := os.CreateTemp("", "sign-cert.key")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caKeyF.Name())
|
||||
|
||||
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
|
||||
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
@ -117,11 +122,11 @@ func Test_signCert(t *testing.T) {
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv))
|
||||
caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
|
||||
|
||||
// failed to read cert
|
||||
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
|
||||
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
@ -129,30 +134,22 @@ func Test_signCert(t *testing.T) {
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caCrtF.Name())
|
||||
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// write a proper ca cert for later
|
||||
ca := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "ca",
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Minute * 200),
|
||||
PublicKey: caPub,
|
||||
IsCA: true,
|
||||
},
|
||||
}
|
||||
b, _ := ca.MarshalToPEM()
|
||||
ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
|
||||
b, _ := ca.MarshalPEM()
|
||||
caCrtF.Write(b)
|
||||
|
||||
// failed to read pub
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
@ -160,11 +157,11 @@ func Test_signCert(t *testing.T) {
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
inPubF, err := os.CreateTemp("", "in.pub")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(inPubF.Name())
|
||||
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
@ -172,116 +169,124 @@ func Test_signCert(t *testing.T) {
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
inPub, _ := x25519Keypair()
|
||||
inPubF.Write(cert.MarshalX25519PublicKey(inPub))
|
||||
inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub))
|
||||
|
||||
// bad ip cidr
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// bad subnet cidr
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// mismatched ca key
|
||||
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
|
||||
caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caKeyF2.Name())
|
||||
caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2))
|
||||
caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
|
||||
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// failed key write
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// create temp key file
|
||||
keyF, err := os.CreateTemp("", "test.key")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF.Name())
|
||||
|
||||
// failed cert write
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
os.Remove(keyF.Name())
|
||||
|
||||
// create temp cert file
|
||||
crtF, err := os.CreateTemp("", "test.crt")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
os.Remove(crtF.Name())
|
||||
|
||||
// test proper cert with removed empty groups and subnets
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// read cert and key files
|
||||
rb, _ := os.ReadFile(keyF.Name())
|
||||
lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
|
||||
assert.Len(t, b, 0)
|
||||
assert.Nil(t, err)
|
||||
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
assert.Empty(t, b)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, lKey, 32)
|
||||
|
||||
rb, _ = os.ReadFile(crtF.Name())
|
||||
lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
|
||||
assert.Len(t, b, 0)
|
||||
assert.Nil(t, err)
|
||||
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
||||
assert.Empty(t, b)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test", lCrt.Details.Name)
|
||||
assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String())
|
||||
assert.Len(t, lCrt.Details.Ips, 1)
|
||||
assert.False(t, lCrt.Details.IsCA)
|
||||
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
|
||||
assert.Len(t, lCrt.Details.Subnets, 3)
|
||||
assert.Len(t, lCrt.Details.PublicKey, 32)
|
||||
assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
|
||||
assert.Equal(t, "test", lCrt.Name())
|
||||
assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
|
||||
assert.Len(t, lCrt.Networks(), 1)
|
||||
assert.False(t, lCrt.IsCA())
|
||||
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
|
||||
assert.Len(t, lCrt.UnsafeNetworks(), 3)
|
||||
assert.Len(t, lCrt.PublicKey(), 32)
|
||||
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
|
||||
|
||||
sns := []string{}
|
||||
for _, sn := range lCrt.Details.Subnets {
|
||||
for _, sn := range lCrt.UnsafeNetworks() {
|
||||
sns = append(sns, sn.String())
|
||||
}
|
||||
assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns)
|
||||
|
||||
issuer, _ := ca.Sha256Sum()
|
||||
assert.Equal(t, issuer, lCrt.Details.Issuer)
|
||||
issuer, _ := ca.Fingerprint()
|
||||
assert.Equal(t, issuer, lCrt.Issuer())
|
||||
|
||||
assert.True(t, lCrt.CheckSignature(caPub))
|
||||
|
||||
@ -290,53 +295,55 @@ func Test_signCert(t *testing.T) {
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
|
||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// read cert file and check pub key matches in-pub
|
||||
rb, _ = os.ReadFile(crtF.Name())
|
||||
lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb)
|
||||
assert.Len(t, b, 0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, lCrt.Details.PublicKey, inPub)
|
||||
lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
|
||||
assert.Empty(t, b)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, lCrt.PublicKey(), inPub)
|
||||
|
||||
// test refuse to sign cert with duration beyond root
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate")
|
||||
os.Remove(keyF.Name())
|
||||
os.Remove(crtF.Name())
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// create valid cert/key for overwrite tests
|
||||
os.Remove(keyF.Name())
|
||||
os.Remove(crtF.Name())
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
|
||||
// test that we won't overwrite existing key file
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// create valid cert/key for overwrite tests
|
||||
os.Remove(keyF.Name())
|
||||
os.Remove(crtF.Name())
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
|
||||
// test that we won't overwrite existing certificate file
|
||||
os.Remove(keyF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
@ -349,11 +356,11 @@ func Test_signCert(t *testing.T) {
|
||||
eb.Reset()
|
||||
|
||||
caKeyF, err = os.CreateTemp("", "sign-cert.key")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caKeyF.Name())
|
||||
|
||||
caCrtF, err = os.CreateTemp("", "sign-cert.crt")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caCrtF.Name())
|
||||
|
||||
// generate the encrypted key
|
||||
@ -362,21 +369,13 @@ func Test_signCert(t *testing.T) {
|
||||
b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams)
|
||||
caKeyF.Write(b)
|
||||
|
||||
ca = cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "ca",
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Minute * 200),
|
||||
PublicKey: caPub,
|
||||
IsCA: true,
|
||||
},
|
||||
}
|
||||
b, _ = ca.MarshalToPEM()
|
||||
ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
|
||||
b, _ = ca.MarshalPEM()
|
||||
caCrtF.Write(b)
|
||||
|
||||
// test with the proper password
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.Nil(t, signCert(args, ob, eb, testpw))
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
@ -385,8 +384,8 @@ func Test_signCert(t *testing.T) {
|
||||
eb.Reset()
|
||||
|
||||
testpw.password = []byte("invalid password")
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.Error(t, signCert(args, ob, eb, testpw))
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, testpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
@ -394,8 +393,8 @@ func Test_signCert(t *testing.T) {
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.Error(t, signCert(args, ob, eb, nopw))
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, nopw))
|
||||
// normally the user hitting enter on the prompt would add newlines between these
|
||||
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
@ -404,8 +403,8 @@ func Test_signCert(t *testing.T) {
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
|
||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
assert.Error(t, signCert(args, ob, eb, errpw))
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, errpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
|
||||
|
||||
rawCACert, err := os.ReadFile(*vf.caPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca: %s", err)
|
||||
return fmt.Errorf("error while reading ca: %w", err)
|
||||
}
|
||||
|
||||
caPool := cert.NewCAPool()
|
||||
for {
|
||||
rawCACert, err = caPool.AddCACertificate(rawCACert)
|
||||
rawCACert, err = caPool.AddCAFromPEM(rawCACert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding ca cert to pool: %s", err)
|
||||
return fmt.Errorf("error while adding ca cert to pool: %w", err)
|
||||
}
|
||||
|
||||
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
|
||||
@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
|
||||
|
||||
rawCert, err := os.ReadFile(*vf.certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read crt; %s", err)
|
||||
return fmt.Errorf("unable to read crt: %w", err)
|
||||
}
|
||||
var errs []error
|
||||
for {
|
||||
if len(rawCert) == 0 {
|
||||
break
|
||||
}
|
||||
c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while parsing crt: %w", err)
|
||||
}
|
||||
rawCert = extra
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, cert.ErrCaNotFound):
|
||||
errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while parsing crt: %s", err)
|
||||
}
|
||||
|
||||
good, err := c.Verify(time.Now(), caPool)
|
||||
if !good {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func verifySummary() string {
|
||||
@ -80,7 +91,7 @@ func verifySummary() string {
|
||||
|
||||
func verifyHelp(out io.Writer) {
|
||||
vf := newVerifyFlags()
|
||||
out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
|
||||
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
|
||||
vf.set.SetOutput(out)
|
||||
vf.set.PrintDefaults()
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
@ -50,34 +51,25 @@ func Test_verify(t *testing.T) {
|
||||
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
|
||||
require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
|
||||
|
||||
// invalid ca at path
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
caFile, err := os.CreateTemp("", "verify-ca")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caFile.Name())
|
||||
|
||||
caFile.WriteString("-----BEGIN NOPE-----")
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
||||
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
||||
|
||||
// make a ca for later
|
||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
ca := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "test-ca",
|
||||
NotBefore: time.Now().Add(time.Hour * -1),
|
||||
NotAfter: time.Now().Add(time.Hour * 2),
|
||||
PublicKey: caPub,
|
||||
IsCA: true,
|
||||
},
|
||||
}
|
||||
ca.Sign(cert.Curve_CURVE25519, caPriv)
|
||||
b, _ := ca.MarshalToPEM()
|
||||
ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
|
||||
b, _ := ca.MarshalPEM()
|
||||
caFile.Truncate(0)
|
||||
caFile.Seek(0, 0)
|
||||
caFile.Write(b)
|
||||
@ -86,38 +78,29 @@ func Test_verify(t *testing.T) {
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError)
|
||||
require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
|
||||
|
||||
// invalid crt at path
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
certFile, err := os.CreateTemp("", "verify-cert")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(certFile.Name())
|
||||
|
||||
certFile.WriteString("-----BEGIN NOPE-----")
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
|
||||
require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
|
||||
|
||||
// unverifiable cert at path
|
||||
_, badPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
certPub, _ := x25519Keypair()
|
||||
signer, _ := ca.Sha256Sum()
|
||||
crt := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "test-cert",
|
||||
NotBefore: time.Now().Add(time.Hour * -1),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
PublicKey: certPub,
|
||||
IsCA: false,
|
||||
Issuer: signer,
|
||||
},
|
||||
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
||||
// Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature
|
||||
pub := crt.PublicKey()
|
||||
for i, _ := range pub {
|
||||
pub[i] = 0
|
||||
}
|
||||
|
||||
crt.Sign(cert.Curve_CURVE25519, badPriv)
|
||||
b, _ = crt.MarshalToPEM()
|
||||
b, _ = crt.MarshalPEM()
|
||||
certFile.Truncate(0)
|
||||
certFile.Seek(0, 0)
|
||||
certFile.Write(b)
|
||||
@ -125,11 +108,11 @@ func Test_verify(t *testing.T) {
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.EqualError(t, err, "certificate signature did not match")
|
||||
require.ErrorIs(t, err, cert.ErrSignatureMismatch)
|
||||
|
||||
// verified cert at path
|
||||
crt.Sign(cert.Curve_CURVE25519, caPriv)
|
||||
b, _ = crt.MarshalToPEM()
|
||||
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
||||
b, _ = crt.MarshalPEM()
|
||||
certFile.Truncate(0)
|
||||
certFile.Seek(0, 0)
|
||||
certFile.Write(b)
|
||||
@ -137,5 +120,5 @@ func Test_verify(t *testing.T) {
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -19,18 +19,18 @@ func TestConfig_Load(t *testing.T) {
|
||||
// invalid yaml
|
||||
c := NewC(l)
|
||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
||||
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
||||
require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
||||
|
||||
// simple multi config merge
|
||||
c = NewC(l)
|
||||
os.RemoveAll(dir)
|
||||
os.Mkdir(dir, 0755)
|
||||
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
|
||||
assert.Nil(t, c.Load(dir))
|
||||
require.NoError(t, c.Load(dir))
|
||||
expected := map[interface{}]interface{}{
|
||||
"outer": map[interface{}]interface{}{
|
||||
"inner": "override",
|
||||
@ -38,9 +38,6 @@ func TestConfig_Load(t *testing.T) {
|
||||
"new": "hi",
|
||||
}
|
||||
assert.Equal(t, expected, c.Settings)
|
||||
|
||||
//TODO: test symlinked file
|
||||
//TODO: test symlinked directory
|
||||
}
|
||||
|
||||
func TestConfig_Get(t *testing.T) {
|
||||
@ -70,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := NewC(l)
|
||||
c.Settings["bool"] = true
|
||||
assert.Equal(t, true, c.GetBool("bool", false))
|
||||
assert.True(t, c.GetBool("bool", false))
|
||||
|
||||
c.Settings["bool"] = "true"
|
||||
assert.Equal(t, true, c.GetBool("bool", false))
|
||||
assert.True(t, c.GetBool("bool", false))
|
||||
|
||||
c.Settings["bool"] = false
|
||||
assert.Equal(t, false, c.GetBool("bool", true))
|
||||
assert.False(t, c.GetBool("bool", true))
|
||||
|
||||
c.Settings["bool"] = "false"
|
||||
assert.Equal(t, false, c.GetBool("bool", true))
|
||||
assert.False(t, c.GetBool("bool", true))
|
||||
|
||||
c.Settings["bool"] = "Y"
|
||||
assert.Equal(t, true, c.GetBool("bool", false))
|
||||
assert.True(t, c.GetBool("bool", false))
|
||||
|
||||
c.Settings["bool"] = "yEs"
|
||||
assert.Equal(t, true, c.GetBool("bool", false))
|
||||
assert.True(t, c.GetBool("bool", false))
|
||||
|
||||
c.Settings["bool"] = "N"
|
||||
assert.Equal(t, false, c.GetBool("bool", true))
|
||||
assert.False(t, c.GetBool("bool", true))
|
||||
|
||||
c.Settings["bool"] = "nO"
|
||||
assert.Equal(t, false, c.GetBool("bool", true))
|
||||
assert.False(t, c.GetBool("bool", true))
|
||||
}
|
||||
|
||||
func TestConfig_HasChanged(t *testing.T) {
|
||||
@ -120,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
done := make(chan bool, 1)
|
||||
dir, err := os.MkdirTemp("", "config-test")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||
|
||||
c := NewC(l)
|
||||
assert.Nil(t, c.Load(dir))
|
||||
require.NoError(t, c.Load(dir))
|
||||
|
||||
assert.False(t, c.HasChanged("outer.inner"))
|
||||
assert.False(t, c.HasChanged("outer"))
|
||||
|
||||
@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
|
||||
case deleteTunnel:
|
||||
if n.hostMap.DeleteHostInfo(hostinfo) {
|
||||
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
|
||||
n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
|
||||
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
|
||||
}
|
||||
|
||||
case closeTunnel:
|
||||
@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
||||
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
|
||||
|
||||
for _, r := range relayFor {
|
||||
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
|
||||
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
|
||||
|
||||
var index uint32
|
||||
var relayFrom netip.Addr
|
||||
@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
||||
index = existing.LocalIndex
|
||||
switch r.Type {
|
||||
case TerminalType:
|
||||
relayFrom = n.intf.myVpnNet.Addr()
|
||||
relayTo = existing.PeerIp
|
||||
relayFrom = n.intf.myVpnAddrs[0]
|
||||
relayTo = existing.PeerAddr
|
||||
case ForwardingType:
|
||||
relayFrom = existing.PeerIp
|
||||
relayTo = newhostinfo.vpnIp
|
||||
relayFrom = existing.PeerAddr
|
||||
relayTo = newhostinfo.vpnAddrs[0]
|
||||
default:
|
||||
// should never happen
|
||||
}
|
||||
@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
||||
n.relayUsedLock.RUnlock()
|
||||
// The relay doesn't exist at all; create some relay state and send the request.
|
||||
var err error
|
||||
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
|
||||
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||
if err != nil {
|
||||
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
||||
continue
|
||||
}
|
||||
switch r.Type {
|
||||
case TerminalType:
|
||||
relayFrom = n.intf.myVpnNet.Addr()
|
||||
relayTo = r.PeerIp
|
||||
relayFrom = n.intf.myVpnAddrs[0]
|
||||
relayTo = r.PeerAddr
|
||||
case ForwardingType:
|
||||
relayFrom = r.PeerIp
|
||||
relayTo = newhostinfo.vpnIp
|
||||
relayFrom = r.PeerAddr
|
||||
relayTo = newhostinfo.vpnAddrs[0]
|
||||
default:
|
||||
// should never happen
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
relayFromB := relayFrom.As4()
|
||||
relayToB := relayTo.As4()
|
||||
|
||||
// Send a CreateRelayRequest to the peer.
|
||||
req := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: index,
|
||||
RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
|
||||
}
|
||||
|
||||
switch newhostinfo.GetCert().Certificate.Version() {
|
||||
case cert.Version1:
|
||||
if !relayFrom.Is4() {
|
||||
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
if !relayTo.Is4() {
|
||||
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
b := relayFrom.As4()
|
||||
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||
b = relayTo.As4()
|
||||
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||
case cert.Version2:
|
||||
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||
req.RelayToAddr = netAddrToProtoAddr(relayTo)
|
||||
default:
|
||||
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := req.Marshal()
|
||||
if err != nil {
|
||||
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
||||
} else {
|
||||
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
n.l.WithFields(logrus.Fields{
|
||||
"relayFrom": req.RelayFromIp,
|
||||
"relayTo": req.RelayToIp,
|
||||
"relayFrom": req.RelayFromAddr,
|
||||
"relayTo": req.RelayToAddr,
|
||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||
"responderRelayIndex": req.ResponderRelayIndex,
|
||||
"vpnIp": newhostinfo.vpnIp}).
|
||||
"vpnAddrs": newhostinfo.vpnAddrs}).
|
||||
Info("send CreateRelayRequest")
|
||||
}
|
||||
}
|
||||
@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
|
||||
return closeTunnel, hostinfo, nil
|
||||
}
|
||||
|
||||
primary := n.hostMap.Hosts[hostinfo.vpnIp]
|
||||
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
|
||||
mainHostInfo := true
|
||||
if primary != nil && primary != hostinfo {
|
||||
mainHostInfo = false
|
||||
@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
||||
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
||||
// Let's sort this out.
|
||||
|
||||
if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
|
||||
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
|
||||
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
|
||||
// The remotes vpn ip is lower than mine. I will not flip.
|
||||
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
|
||||
// vpn addr is static across all tunnels for this host pair so lets
|
||||
// use that to determine if we should consider swapping.
|
||||
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
|
||||
// Their primary vpn addr is less than mine. Do not swap.
|
||||
return false
|
||||
}
|
||||
|
||||
certState := n.intf.pki.GetCertState()
|
||||
return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
|
||||
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||
// settle down.
|
||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||
}
|
||||
|
||||
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
n.hostMap.Lock()
|
||||
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
|
||||
if n.hostMap.Hosts[current.vpnIp] == primary {
|
||||
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
|
||||
n.hostMap.unlockedMakePrimary(current)
|
||||
}
|
||||
n.hostMap.Unlock()
|
||||
@ -436,8 +458,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
|
||||
return false
|
||||
}
|
||||
|
||||
valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
|
||||
if valid {
|
||||
caPool := n.intf.pki.GetCAPool()
|
||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -446,9 +469,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
|
||||
return false
|
||||
}
|
||||
|
||||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
hostinfo.logger(n.l).WithError(err).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
|
||||
return true
|
||||
@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
}
|
||||
|
||||
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
certState := n.intf.pki.GetCertState()
|
||||
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
|
||||
cs := n.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
myCrt := cs.getCertificate(curCrt.Version())
|
||||
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||
return
|
||||
}
|
||||
|
||||
n.l.WithField("vpnIp", hostinfo.vpnIp).
|
||||
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
|
||||
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
}
|
||||
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@ -15,6 +14,7 @@ import (
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestLighthouse() *LightHouse {
|
||||
@ -35,20 +35,19 @@ func newTestLighthouse() *LightHouse {
|
||||
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
hostMap := newHostMap(l)
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
cs := &CertState{
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
Certificate: &cert.NebulaCertificate{},
|
||||
RawCertificateNoKey: []byte{},
|
||||
defaultVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
@ -75,12 +74,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
|
||||
// Add an ip we have established a connection w/ to hostmap
|
||||
hostinfo := &HostInfo{
|
||||
vpnIp: vpnIp,
|
||||
vpnAddrs: []netip.Addr{vpnIp},
|
||||
localIndexId: 1099,
|
||||
remoteIndexId: 9901,
|
||||
}
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
myCert: &cert.NebulaCertificate{},
|
||||
myCert: &dummyCert{version: cert.Version1},
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
@ -89,7 +88,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
nc.Out(hostinfo.localIndexId)
|
||||
nc.In(hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.out, hostinfo.localIndexId)
|
||||
|
||||
@ -106,32 +105,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
|
||||
// Do a final traffic check tick, the host should now be removed
|
||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
||||
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
}
|
||||
|
||||
func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
hostMap := newHostMap(l)
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
cs := &CertState{
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
Certificate: &cert.NebulaCertificate{},
|
||||
RawCertificateNoKey: []byte{},
|
||||
defaultVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
@ -158,12 +156,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
|
||||
// Add an ip we have established a connection w/ to hostmap
|
||||
hostinfo := &HostInfo{
|
||||
vpnIp: vpnIp,
|
||||
vpnAddrs: []netip.Addr{vpnIp},
|
||||
localIndexId: 1099,
|
||||
remoteIndexId: 9901,
|
||||
}
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
myCert: &cert.NebulaCertificate{},
|
||||
myCert: &dummyCert{version: cert.Version1},
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
@ -171,8 +169,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
// We saw traffic out to vpnIp
|
||||
nc.Out(hostinfo.localIndexId)
|
||||
nc.In(hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
|
||||
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
||||
@ -188,7 +186,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
|
||||
// We saw traffic, should no longer be pending deletion
|
||||
nc.In(hostinfo.localIndexId)
|
||||
@ -197,7 +195,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
}
|
||||
|
||||
// Check if we can disconnect the peer.
|
||||
@ -206,55 +204,48 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
now := time.Now()
|
||||
l := test.NewLogger()
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(172, 1, 1, 2),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
hostMap := newHostMap(l)
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
// Generate keys for CA and peer's cert.
|
||||
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
|
||||
caCert := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "ca",
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(1 * time.Hour),
|
||||
IsCA: true,
|
||||
PublicKey: pubCA,
|
||||
},
|
||||
tbs := &cert.TBSCertificate{
|
||||
Version: 1,
|
||||
Name: "ca",
|
||||
IsCA: true,
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(1 * time.Hour),
|
||||
PublicKey: pubCA,
|
||||
}
|
||||
|
||||
assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
|
||||
ncp := &cert.NebulaCAPool{
|
||||
CAs: cert.NewCAPool().CAs,
|
||||
}
|
||||
ncp.CAs["ca"] = &caCert
|
||||
caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
|
||||
require.NoError(t, err)
|
||||
ncp := cert.NewCAPool()
|
||||
require.NoError(t, ncp.AddCA(caCert))
|
||||
|
||||
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
|
||||
peerCert := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Subnets: []*net.IPNet{},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(60 * time.Second),
|
||||
PublicKey: pubCrt,
|
||||
IsCA: false,
|
||||
Issuer: "ca",
|
||||
},
|
||||
tbs = &cert.TBSCertificate{
|
||||
Version: 1,
|
||||
Name: "host",
|
||||
Networks: []netip.Prefix{vpncidr},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(60 * time.Second),
|
||||
PublicKey: pubCrt,
|
||||
}
|
||||
assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
|
||||
peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
|
||||
require.NoError(t, err)
|
||||
|
||||
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
||||
|
||||
cs := &CertState{
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
Certificate: &cert.NebulaCertificate{},
|
||||
RawCertificateNoKey: []byte{},
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
@ -280,10 +271,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
ifce.connectionManager = nc
|
||||
|
||||
hostinfo := &HostInfo{
|
||||
vpnIp: vpnIp,
|
||||
vpnAddrs: []netip.Addr{vpnIp},
|
||||
ConnectionState: &ConnectionState{
|
||||
myCert: &cert.NebulaCertificate{},
|
||||
peerCert: &peerCert,
|
||||
myCert: &dummyCert{},
|
||||
peerCert: cachedPeerCert,
|
||||
H: &noise.HandshakeState{},
|
||||
},
|
||||
}
|
||||
@ -303,3 +294,114 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
invalid = nc.isInvalidCertificate(nextTick, hostinfo)
|
||||
assert.True(t, invalid)
|
||||
}
|
||||
|
||||
type dummyCert struct {
|
||||
version cert.Version
|
||||
curve cert.Curve
|
||||
groups []string
|
||||
isCa bool
|
||||
issuer string
|
||||
name string
|
||||
networks []netip.Prefix
|
||||
notAfter time.Time
|
||||
notBefore time.Time
|
||||
publicKey []byte
|
||||
signature []byte
|
||||
unsafeNetworks []netip.Prefix
|
||||
}
|
||||
|
||||
func (d *dummyCert) Version() cert.Version {
|
||||
return d.version
|
||||
}
|
||||
|
||||
func (d *dummyCert) Curve() cert.Curve {
|
||||
return d.curve
|
||||
}
|
||||
|
||||
func (d *dummyCert) Groups() []string {
|
||||
return d.groups
|
||||
}
|
||||
|
||||
func (d *dummyCert) IsCA() bool {
|
||||
return d.isCa
|
||||
}
|
||||
|
||||
func (d *dummyCert) Issuer() string {
|
||||
return d.issuer
|
||||
}
|
||||
|
||||
func (d *dummyCert) Name() string {
|
||||
return d.name
|
||||
}
|
||||
|
||||
func (d *dummyCert) Networks() []netip.Prefix {
|
||||
return d.networks
|
||||
}
|
||||
|
||||
func (d *dummyCert) NotAfter() time.Time {
|
||||
return d.notAfter
|
||||
}
|
||||
|
||||
func (d *dummyCert) NotBefore() time.Time {
|
||||
return d.notBefore
|
||||
}
|
||||
|
||||
func (d *dummyCert) PublicKey() []byte {
|
||||
return d.publicKey
|
||||
}
|
||||
|
||||
func (d *dummyCert) Signature() []byte {
|
||||
return d.signature
|
||||
}
|
||||
|
||||
func (d *dummyCert) UnsafeNetworks() []netip.Prefix {
|
||||
return d.unsafeNetworks
|
||||
}
|
||||
|
||||
func (d *dummyCert) MarshalForHandshakes() ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *dummyCert) Sign(curve cert.Curve, key []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dummyCert) CheckSignature(key []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *dummyCert) Expired(t time.Time) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dummyCert) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d *dummyCert) Marshal() ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *dummyCert) MarshalPEM() ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *dummyCert) Fingerprint() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (d *dummyCert) MarshalJSON() ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *dummyCert) Copy() cert.Certificate {
|
||||
return d
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@ package nebula
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@ -18,54 +19,54 @@ type ConnectionState struct {
|
||||
eKey *NebulaCipherState
|
||||
dKey *NebulaCipherState
|
||||
H *noise.HandshakeState
|
||||
myCert *cert.NebulaCertificate
|
||||
peerCert *cert.NebulaCertificate
|
||||
myCert cert.Certificate
|
||||
peerCert *cert.CachedCertificate
|
||||
initiator bool
|
||||
messageCounter atomic.Uint64
|
||||
window *Bits
|
||||
writeLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
||||
var dhFunc noise.DHFunc
|
||||
switch certState.Certificate.Details.Curve {
|
||||
switch crt.Curve() {
|
||||
case cert.Curve_CURVE25519:
|
||||
dhFunc = noise.DH25519
|
||||
case cert.Curve_P256:
|
||||
if certState.Certificate.Pkcs11Backed {
|
||||
if cs.pkcs11Backed {
|
||||
dhFunc = noiseutil.DHP256PKCS11
|
||||
} else {
|
||||
dhFunc = noiseutil.DHP256
|
||||
}
|
||||
default:
|
||||
l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
|
||||
return nil
|
||||
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
|
||||
}
|
||||
|
||||
var cs noise.CipherSuite
|
||||
if cipher == "chachapoly" {
|
||||
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
var ncs noise.CipherSuite
|
||||
if cs.cipher == "chachapoly" {
|
||||
ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
} else {
|
||||
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
||||
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
||||
}
|
||||
|
||||
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
|
||||
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
||||
|
||||
b := NewBits(ReplayWindow)
|
||||
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
||||
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
|
||||
b.Update(l, 0)
|
||||
|
||||
hs, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: cs,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: static,
|
||||
PresharedKey: psk,
|
||||
PresharedKeyPlacement: pskStage,
|
||||
CipherSuite: ncs,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: static,
|
||||
//NOTE: These should come from CertState (pki.go) when we finally implement it
|
||||
PresharedKey: []byte{},
|
||||
PresharedKeyPlacement: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, fmt.Errorf("NewConnectionState: %s", err)
|
||||
}
|
||||
|
||||
// The queue and ready params prevent a counter race that would happen when
|
||||
@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
|
||||
H: hs,
|
||||
initiator: initiator,
|
||||
window: b,
|
||||
myCert: certState.Certificate,
|
||||
myCert: crt,
|
||||
}
|
||||
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
||||
ci.messageCounter.Add(2)
|
||||
|
||||
return ci
|
||||
return ci, nil
|
||||
}
|
||||
|
||||
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
||||
@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
||||
"message_counter": cs.messageCounter.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
func (cs *ConnectionState) Curve() cert.Curve {
|
||||
return cs.myCert.Curve()
|
||||
}
|
||||
|
||||
73
control.go
73
control.go
@ -19,9 +19,9 @@ import (
|
||||
type controlEach func(h *HostInfo)
|
||||
|
||||
type controlHostLister interface {
|
||||
QueryVpnIp(vpnIp netip.Addr) *HostInfo
|
||||
QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
|
||||
ForEachIndex(each controlEach)
|
||||
ForEachVpnIp(each controlEach)
|
||||
ForEachVpnAddr(each controlEach)
|
||||
GetPreferredRanges() []netip.Prefix
|
||||
}
|
||||
|
||||
@ -37,15 +37,15 @@ type Control struct {
|
||||
}
|
||||
|
||||
type ControlHostInfo struct {
|
||||
VpnIp netip.Addr `json:"vpnIp"`
|
||||
LocalIndex uint32 `json:"localIndex"`
|
||||
RemoteIndex uint32 `json:"remoteIndex"`
|
||||
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
|
||||
Cert *cert.NebulaCertificate `json:"cert"`
|
||||
MessageCounter uint64 `json:"messageCounter"`
|
||||
CurrentRemote netip.AddrPort `json:"currentRemote"`
|
||||
CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"`
|
||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||
VpnAddrs []netip.Addr `json:"vpnAddrs"`
|
||||
LocalIndex uint32 `json:"localIndex"`
|
||||
RemoteIndex uint32 `json:"remoteIndex"`
|
||||
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
|
||||
Cert cert.Certificate `json:"cert"`
|
||||
MessageCounter uint64 `json:"messageCounter"`
|
||||
CurrentRemote netip.AddrPort `json:"currentRemote"`
|
||||
CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"`
|
||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||
}
|
||||
|
||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||
@ -130,15 +130,18 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
|
||||
}
|
||||
|
||||
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
|
||||
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate {
|
||||
if c.f.myVpnNet.Addr() == vpnIp {
|
||||
return c.f.pki.GetCertState().Certificate
|
||||
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
|
||||
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
|
||||
if found {
|
||||
// Only returning the default certificate since its impossible
|
||||
// for any other host but ourselves to have more than 1
|
||||
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
|
||||
}
|
||||
hi := c.f.hostMap.QueryVpnIp(vpnIp)
|
||||
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
|
||||
if hi == nil {
|
||||
return nil
|
||||
}
|
||||
return hi.GetCert()
|
||||
return hi.GetCert().Certificate.Copy()
|
||||
}
|
||||
|
||||
// CreateTunnel creates a new tunnel to the given vpn ip.
|
||||
@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
|
||||
|
||||
// PrintTunnel creates a new tunnel to the given vpn ip.
|
||||
func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
|
||||
hi := c.f.hostMap.QueryVpnIp(vpnIp)
|
||||
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
|
||||
if hi == nil {
|
||||
return nil
|
||||
}
|
||||
@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
|
||||
return hi.CopyCache()
|
||||
}
|
||||
|
||||
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
|
||||
// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
|
||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||
func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
|
||||
func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
|
||||
var hl controlHostLister
|
||||
if pending {
|
||||
hl = c.f.handshakeManager
|
||||
@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
|
||||
hl = c.f.hostMap
|
||||
}
|
||||
|
||||
h := hl.QueryVpnIp(vpnIp)
|
||||
h := hl.QueryVpnAddr(vpnAddr)
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
|
||||
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
|
||||
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
|
||||
hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
|
||||
if hostInfo == nil {
|
||||
return nil
|
||||
}
|
||||
@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
|
||||
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
|
||||
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
|
||||
hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
|
||||
if hostInfo == nil {
|
||||
return false
|
||||
}
|
||||
@ -224,19 +227,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
|
||||
// CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
|
||||
// the int returned is a count of tunnels closed
|
||||
func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
||||
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
||||
lighthouses := c.f.lightHouse.GetLighthouses()
|
||||
|
||||
shutdown := func(h *HostInfo) {
|
||||
if excludeLighthouses {
|
||||
if _, ok := lighthouses[h.vpnIp]; ok {
|
||||
return
|
||||
}
|
||||
if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
|
||||
return
|
||||
}
|
||||
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||
c.f.closeTunnel(h)
|
||||
|
||||
c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
|
||||
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
|
||||
Debug("Sending close tunnel message")
|
||||
closed++
|
||||
}
|
||||
@ -246,7 +244,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
||||
// Grab the hostMap lock to access the Relays map
|
||||
c.f.hostMap.Lock()
|
||||
for _, relayingHost := range c.f.hostMap.Relays {
|
||||
relayingHosts[relayingHost.vpnIp] = relayingHost
|
||||
relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
|
||||
}
|
||||
c.f.hostMap.Unlock()
|
||||
|
||||
@ -254,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
||||
// Grab the hostMap lock to access the Hosts map
|
||||
c.f.hostMap.Lock()
|
||||
for _, relayHost := range c.f.hostMap.Indexes {
|
||||
if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
|
||||
if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
|
||||
hostInfos = append(hostInfos, relayHost)
|
||||
}
|
||||
}
|
||||
@ -274,9 +272,8 @@ func (c *Control) Device() overlay.Device {
|
||||
}
|
||||
|
||||
func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
|
||||
|
||||
chi := ControlHostInfo{
|
||||
VpnIp: h.vpnIp,
|
||||
VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)),
|
||||
LocalIndex: h.localIndexId,
|
||||
RemoteIndex: h.remoteIndexId,
|
||||
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
|
||||
@ -285,12 +282,16 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
|
||||
CurrentRemote: h.remote,
|
||||
}
|
||||
|
||||
for i, a := range h.vpnAddrs {
|
||||
chi.VpnAddrs[i] = a
|
||||
}
|
||||
|
||||
if h.ConnectionState != nil {
|
||||
chi.MessageCounter = h.ConnectionState.messageCounter.Load()
|
||||
}
|
||||
|
||||
if c := h.GetCert(); c != nil {
|
||||
chi.Cert = c.Copy()
|
||||
chi.Cert = c.Certificate.Copy()
|
||||
}
|
||||
|
||||
return chi
|
||||
@ -299,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
|
||||
func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
|
||||
hosts := make([]ControlHostInfo, 0)
|
||||
pr := hl.GetPreferredRanges()
|
||||
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
|
||||
hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
|
||||
hosts = append(hosts, copyHostInfo(hostinfo, pr))
|
||||
})
|
||||
return hosts
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
@ -14,10 +13,13 @@ import (
|
||||
)
|
||||
|
||||
func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
//TODO: CERT-V2 with multiple certificate versions we have a problem with this test
|
||||
// Some certs versions have different characteristics and each version implements their own Copy() func
|
||||
// which means this is not a good place to test for exposing memory
|
||||
l := test.NewLogger()
|
||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||
// To properly ensure we are not exposing core memory to the caller
|
||||
hm := newHostMap(l, netip.Prefix{})
|
||||
hm := newHostMap(l)
|
||||
hm.preferredRanges.Store(&[]netip.Prefix{})
|
||||
|
||||
remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
|
||||
@ -33,42 +35,27 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
crt := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "test",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Subnets: []*net.IPNet{},
|
||||
Groups: []string{"default-group"},
|
||||
NotBefore: time.Unix(1, 0),
|
||||
NotAfter: time.Unix(2, 0),
|
||||
PublicKey: []byte{5, 6, 7, 8},
|
||||
IsCA: false,
|
||||
Issuer: "the-issuer",
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
Signature: []byte{1, 2, 1, 2, 1, 3},
|
||||
}
|
||||
|
||||
remotes := NewRemoteList(nil)
|
||||
remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
|
||||
remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
|
||||
remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
|
||||
remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
|
||||
remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
|
||||
|
||||
vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
|
||||
assert.True(t, ok)
|
||||
|
||||
crt := &dummyCert{}
|
||||
hm.unlockedAddHostInfo(&HostInfo{
|
||||
remote: remote1,
|
||||
remotes: remotes,
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: crt,
|
||||
peerCert: &cert.CachedCertificate{Certificate: crt},
|
||||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
vpnIp: vpnIp,
|
||||
vpnAddrs: []netip.Addr{vpnIp},
|
||||
relayState: RelayState{
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByIp: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}, &Interface{})
|
||||
|
||||
@ -83,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
vpnIp: vpnIp2,
|
||||
vpnAddrs: []netip.Addr{vpnIp2},
|
||||
relayState: RelayState{
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByIp: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}, &Interface{})
|
||||
|
||||
@ -98,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
l: logrus.New(),
|
||||
}
|
||||
|
||||
thi := c.GetHostInfoByVpnIp(vpnIp, false)
|
||||
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
||||
|
||||
expectedInfo := ControlHostInfo{
|
||||
VpnIp: vpnIp,
|
||||
VpnAddrs: []netip.Addr{vpnIp},
|
||||
LocalIndex: 201,
|
||||
RemoteIndex: 200,
|
||||
RemoteAddrs: []netip.AddrPort{remote2, remote1},
|
||||
@ -113,14 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
}
|
||||
|
||||
// Make sure we don't have any unexpected fields
|
||||
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
||||
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
||||
assert.EqualValues(t, &expectedInfo, thi)
|
||||
//TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
|
||||
//test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
|
||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||
assert.NotPanics(t, func() {
|
||||
thi = c.GetHostInfoByVpnIp(vpnIp2, false)
|
||||
thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -6,8 +6,6 @@ package nebula
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/slackhq/nebula/header"
|
||||
@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
|
||||
// This is necessary if you did not configure static hosts or are not running a lighthouse
|
||||
func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
|
||||
c.f.lightHouse.Lock()
|
||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
|
||||
remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
|
||||
remoteList.Lock()
|
||||
defer remoteList.Unlock()
|
||||
c.f.lightHouse.Unlock()
|
||||
|
||||
if toAddr.Addr().Is4() {
|
||||
remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
|
||||
remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
|
||||
} else {
|
||||
remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
|
||||
remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
|
||||
// This is necessary to inform an initiator of possible relays for communicating with a responder
|
||||
func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
|
||||
c.f.lightHouse.Lock()
|
||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
|
||||
remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
|
||||
remoteList.Lock()
|
||||
defer remoteList.Unlock()
|
||||
c.f.lightHouse.Unlock()
|
||||
|
||||
remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
|
||||
remoteList.unlockedSetRelay(vpnIp, relayVpnIps)
|
||||
}
|
||||
|
||||
// GetFromTun will pull a packet off the tun side of nebula
|
||||
@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
|
||||
}
|
||||
|
||||
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
|
||||
func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
|
||||
//TODO: IPV6-WORK
|
||||
ip := layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(),
|
||||
DstIP: toIp.Unmap().AsSlice(),
|
||||
func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
|
||||
serialize := make([]gopacket.SerializableLayer, 0)
|
||||
var netLayer gopacket.NetworkLayer
|
||||
if toAddr.Is6() {
|
||||
if !fromAddr.Is6() {
|
||||
panic("Cant send ipv6 to ipv4")
|
||||
}
|
||||
ip := &layers.IPv6{
|
||||
Version: 6,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
SrcIP: fromAddr.Unmap().AsSlice(),
|
||||
DstIP: toAddr.Unmap().AsSlice(),
|
||||
}
|
||||
serialize = append(serialize, ip)
|
||||
netLayer = ip
|
||||
} else {
|
||||
if !fromAddr.Is4() {
|
||||
panic("Cant send ipv4 to ipv6")
|
||||
}
|
||||
|
||||
ip := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
SrcIP: fromAddr.Unmap().AsSlice(),
|
||||
DstIP: toAddr.Unmap().AsSlice(),
|
||||
}
|
||||
serialize = append(serialize, ip)
|
||||
netLayer = ip
|
||||
}
|
||||
|
||||
udp := layers.UDP{
|
||||
SrcPort: layers.UDPPort(fromPort),
|
||||
DstPort: layers.UDPPort(toPort),
|
||||
}
|
||||
err := udp.SetNetworkLayerForChecksum(&ip)
|
||||
err := udp.SetNetworkLayerForChecksum(netLayer)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
|
||||
|
||||
serialize = append(serialize, &udp, gopacket.Payload(data))
|
||||
err = gopacket.SerializeLayers(buffer, opt, serialize...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
|
||||
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
|
||||
}
|
||||
|
||||
func (c *Control) GetVpnIp() netip.Addr {
|
||||
return c.f.myVpnNet.Addr()
|
||||
func (c *Control) GetVpnAddrs() []netip.Addr {
|
||||
return c.f.myVpnAddrs
|
||||
}
|
||||
|
||||
func (c *Control) GetUDPAddr() netip.AddrPort {
|
||||
@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
|
||||
}
|
||||
|
||||
func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
|
||||
hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
|
||||
hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
|
||||
if hostinfo == nil {
|
||||
return false
|
||||
}
|
||||
@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
|
||||
return c.f.hostMap
|
||||
}
|
||||
|
||||
func (c *Control) GetCert() *cert.NebulaCertificate {
|
||||
return c.f.pki.GetCertState().Certificate
|
||||
func (c *Control) GetCertState() *CertState {
|
||||
return c.f.pki.getCertState()
|
||||
}
|
||||
|
||||
func (c *Control) ReHandshake(vpnIp netip.Addr) {
|
||||
|
||||
118
dns_server.go
118
dns_server.go
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
@ -21,24 +22,39 @@ var dnsAddr string
|
||||
|
||||
type dnsRecords struct {
|
||||
sync.RWMutex
|
||||
dnsMap map[string]string
|
||||
hostMap *HostMap
|
||||
l *logrus.Logger
|
||||
dnsMap4 map[string]netip.Addr
|
||||
dnsMap6 map[string]netip.Addr
|
||||
hostMap *HostMap
|
||||
myVpnAddrsTable *bart.Table[struct{}]
|
||||
}
|
||||
|
||||
func newDnsRecords(hostMap *HostMap) *dnsRecords {
|
||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
||||
return &dnsRecords{
|
||||
dnsMap: make(map[string]string),
|
||||
hostMap: hostMap,
|
||||
l: l,
|
||||
dnsMap4: make(map[string]netip.Addr),
|
||||
dnsMap6: make(map[string]netip.Addr),
|
||||
hostMap: hostMap,
|
||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Query(data string) string {
|
||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||
data = strings.ToLower(data)
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
|
||||
return r
|
||||
switch q {
|
||||
case dns.TypeA:
|
||||
if r, ok := d.dnsMap4[data]; ok {
|
||||
return r
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if r, ok := d.dnsMap6[data]; ok {
|
||||
return r
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) QueryCert(data string) string {
|
||||
@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
hostinfo := d.hostMap.QueryVpnIp(ip)
|
||||
hostinfo := d.hostMap.QueryVpnAddr(ip)
|
||||
if hostinfo == nil {
|
||||
return ""
|
||||
}
|
||||
@ -57,43 +73,69 @@ func (d *dnsRecords) QueryCert(data string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
cert := q.Details
|
||||
c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
|
||||
return c
|
||||
b, err := q.Certificate.MarshalJSON()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Add(host, data string) {
|
||||
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
||||
host = strings.ToLower(host)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.dnsMap[strings.ToLower(host)] = data
|
||||
haveV4 := false
|
||||
haveV6 := false
|
||||
for _, addr := range addresses {
|
||||
if addr.Is4() && !haveV4 {
|
||||
d.dnsMap4[host] = addr
|
||||
haveV4 = true
|
||||
} else if addr.Is6() && !haveV6 {
|
||||
d.dnsMap6[host] = addr
|
||||
haveV6 = true
|
||||
}
|
||||
if haveV4 && haveV6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
||||
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
||||
a, _, _ := net.SplitHostPort(addr)
|
||||
b, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if b.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
_, found := d.myVpnAddrsTable.Lookup(b)
|
||||
return found //if we found it in this table, it's good
|
||||
}
|
||||
|
||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
for _, q := range m.Question {
|
||||
switch q.Qtype {
|
||||
case dns.TypeA:
|
||||
l.Debugf("Query for A %s", q.Name)
|
||||
ip := dnsR.Query(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||
ip := d.Query(q.Qtype, q.Name)
|
||||
if ip.IsValid() {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||
if err == nil {
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
}
|
||||
case dns.TypeTXT:
|
||||
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
||||
b, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
// We only answer these queries from nebula nodes or localhost
|
||||
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||
return
|
||||
}
|
||||
|
||||
// We don't answer these queries from non nebula nodes or localhost
|
||||
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
|
||||
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
|
||||
return
|
||||
}
|
||||
l.Debugf("Query for TXT %s", q.Name)
|
||||
ip := dnsR.QueryCert(q.Name)
|
||||
d.l.Debugf("Query for TXT %s", q.Name)
|
||||
ip := d.QueryCert(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||
if err == nil {
|
||||
@ -108,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
|
||||
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Compress = false
|
||||
|
||||
switch r.Opcode {
|
||||
case dns.OpcodeQuery:
|
||||
parseQuery(l, m, w)
|
||||
d.parseQuery(m, w)
|
||||
}
|
||||
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
|
||||
dnsR = newDnsRecords(hostMap)
|
||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||
dnsR = newDnsRecords(l, cs, hostMap)
|
||||
|
||||
// attach request handler func
|
||||
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
handleDnsRequest(l, w, r)
|
||||
})
|
||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
reloadDns(l, c)
|
||||
|
||||
@ -1,23 +1,38 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParsequery(t *testing.T) {
|
||||
//TODO: This test is basically pointless
|
||||
l := logrus.New()
|
||||
hostMap := &HostMap{}
|
||||
ds := newDnsRecords(hostMap)
|
||||
ds.Add("test.com.com", "1.2.3.4")
|
||||
ds := newDnsRecords(l, &CertState{}, hostMap)
|
||||
addrs := []netip.Addr{
|
||||
netip.MustParseAddr("1.2.3.4"),
|
||||
netip.MustParseAddr("1.2.3.5"),
|
||||
netip.MustParseAddr("fd01::24"),
|
||||
netip.MustParseAddr("fd01::25"),
|
||||
}
|
||||
ds.Add("test.com.com", addrs)
|
||||
|
||||
m := new(dns.Msg)
|
||||
m := &dns.Msg{}
|
||||
m.SetQuestion("test.com.com", dns.TypeA)
|
||||
ds.parseQuery(m, nil)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
||||
|
||||
//parseQuery(m)
|
||||
m = &dns.Msg{}
|
||||
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
||||
ds.parseQuery(m, nil)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
||||
}
|
||||
|
||||
func Test_getDnsServerAddr(t *testing.T) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
125
e2e/helpers.go
125
e2e/helpers.go
@ -1,125 +0,0 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
// NewTestCaCert will generate a CA cert
|
||||
func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
}
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
}
|
||||
|
||||
nc := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "test ca",
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
IsCA: true,
|
||||
InvertedGroups: make(map[string]struct{}),
|
||||
},
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
nc.Details.Ips = make([]*net.IPNet, len(ips))
|
||||
for i, ip := range ips {
|
||||
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
|
||||
}
|
||||
}
|
||||
|
||||
if len(subnets) > 0 {
|
||||
nc.Details.Subnets = make([]*net.IPNet, len(subnets))
|
||||
for i, ip := range subnets {
|
||||
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groups) > 0 {
|
||||
nc.Details.Groups = groups
|
||||
}
|
||||
|
||||
err = nc.Sign(cert.Curve_CURVE25519, priv)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := nc.MarshalToPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nc, pub, priv, pem
|
||||
}
|
||||
|
||||
// NewTestCert will generate a signed certificate with the provided details.
|
||||
// Expiry times are defaulted if you do not pass them in
|
||||
func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
|
||||
issuer, err := ca.Sha256Sum()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
}
|
||||
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
}
|
||||
|
||||
pub, rawPriv := x25519Keypair()
|
||||
ipb := ip.Addr().AsSlice()
|
||||
nc := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: name,
|
||||
Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
|
||||
//Subnets: subnets,
|
||||
Groups: groups,
|
||||
NotBefore: time.Unix(before.Unix(), 0),
|
||||
NotAfter: time.Unix(after.Unix(), 0),
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
Issuer: issuer,
|
||||
InvertedGroups: make(map[string]struct{}),
|
||||
},
|
||||
}
|
||||
|
||||
err = nc.Sign(ca.Details.Curve, key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := nc.MarshalToPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
|
||||
}
|
||||
|
||||
func x25519Keypair() ([]byte, []byte) {
|
||||
privkey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return pubkey, privkey
|
||||
}
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -17,6 +18,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -26,27 +28,37 @@ import (
|
||||
type m map[string]interface{}
|
||||
|
||||
// newSimpleServer creates a nebula instance with many assumptions
|
||||
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
|
||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||
}
|
||||
|
||||
if len(vpnNetworks) == 0 {
|
||||
panic("no vpn networks")
|
||||
}
|
||||
|
||||
var udpAddr netip.AddrPort
|
||||
if vpnIpNet.Addr().Is4() {
|
||||
budpIp := vpnIpNet.Addr().As4()
|
||||
if vpnNetworks[0].Addr().Is4() {
|
||||
budpIp := vpnNetworks[0].Addr().As4()
|
||||
budpIp[1] -= 128
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||
} else {
|
||||
budpIp := vpnIpNet.Addr().As16()
|
||||
budpIp[13] -= 128
|
||||
budpIp := vpnNetworks[0].Addr().As16()
|
||||
// beef for funsies
|
||||
budpIp[2] = 190
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
|
||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||
|
||||
caB, err := caCrt.MarshalToPEM()
|
||||
caB, err := caCrt.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -88,11 +100,16 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
|
||||
}
|
||||
|
||||
if overrides != nil {
|
||||
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
|
||||
final := m{}
|
||||
err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mc = overrides
|
||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mc = final
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
@ -109,7 +126,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return control, vpnIpNet, udpAddr, c
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
type doneCb func()
|
||||
@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||
|
||||
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
||||
// Send a packet from them to me
|
||||
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
|
||||
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
||||
bPacket := r.RouteForAllUntilTxTun(controlA)
|
||||
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
|
||||
|
||||
// And once more from me to them
|
||||
controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
|
||||
controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
|
||||
aPacket := r.RouteForAllUntilTxTun(controlB)
|
||||
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
||||
}
|
||||
|
||||
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
|
||||
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
|
||||
// Get both host infos
|
||||
hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
|
||||
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
|
||||
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
||||
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
||||
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||
|
||||
hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
|
||||
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
|
||||
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
||||
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||
|
||||
// Check that both vpn and real addr are correct
|
||||
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
|
||||
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
|
||||
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
||||
assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
|
||||
|
||||
assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
|
||||
assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
|
||||
@ -160,25 +178,36 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
|
||||
// Check that our indexes match
|
||||
assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
|
||||
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
|
||||
|
||||
//TODO: Would be nice to assert this memory
|
||||
//checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
|
||||
// hBbyIndex := hmA.Indexes[hBinA.localIndexId]
|
||||
// assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
|
||||
// assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
|
||||
//
|
||||
// //TODO: remote indexes are susceptible to collision
|
||||
// hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
|
||||
// assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
|
||||
// assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
|
||||
//}
|
||||
//
|
||||
//// Check hostmap indexes too
|
||||
//checkIndexes("hmA", hmA, hBinA)
|
||||
//checkIndexes("hmB", hmB, hAinB)
|
||||
}
|
||||
|
||||
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||
if toIp.Is6() {
|
||||
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
|
||||
} else {
|
||||
assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
|
||||
}
|
||||
}
|
||||
|
||||
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
|
||||
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
||||
assert.NotNil(t, v6, "No ipv6 data found")
|
||||
|
||||
assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
|
||||
assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
|
||||
|
||||
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
|
||||
assert.NotNil(t, udp, "No udp data found")
|
||||
|
||||
assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
|
||||
assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
|
||||
|
||||
data := packet.ApplicationLayer()
|
||||
assert.NotNil(t, data)
|
||||
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
||||
}
|
||||
|
||||
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||
assert.NotNil(t, v4, "No ipv4 data found")
|
||||
@ -197,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
|
||||
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
||||
}
|
||||
|
||||
func getAddrs(ns []netip.Prefix) []netip.Addr {
|
||||
var a []netip.Addr
|
||||
for _, n := range ns {
|
||||
a = append(a, n.Addr())
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func NewTestLogger() *logrus.Logger {
|
||||
l := logrus.New()
|
||||
|
||||
|
||||
@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
|
||||
var lines []string
|
||||
var globalLines []*edge
|
||||
|
||||
clusterName := strings.Trim(c.GetCert().Details.Name, " ")
|
||||
clusterVpnIp := c.GetCert().Details.Ips[0].IP
|
||||
crt := c.GetCertState().GetDefaultCertificate()
|
||||
clusterName := strings.Trim(crt.Name(), " ")
|
||||
clusterVpnIp := crt.Networks()[0].Addr()
|
||||
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
|
||||
|
||||
hm := c.GetHostmap()
|
||||
@ -101,8 +102,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
|
||||
for _, idx := range indexes {
|
||||
hi, ok := hm.Indexes[idx]
|
||||
if ok {
|
||||
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
|
||||
remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
|
||||
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
|
||||
remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
|
||||
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
|
||||
_ = hi
|
||||
}
|
||||
|
||||
@ -10,8 +10,8 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
||||
panic("Duplicate listen address: " + addr.String())
|
||||
}
|
||||
|
||||
r.vpnControls[c.GetVpnIp()] = c
|
||||
for _, vpnAddr := range c.GetVpnAddrs() {
|
||||
r.vpnControls[vpnAddr] = c
|
||||
}
|
||||
|
||||
r.controls[addr] = c
|
||||
}
|
||||
|
||||
@ -213,11 +216,11 @@ func (r *R) renderFlow() {
|
||||
continue
|
||||
}
|
||||
participants[addr] = struct{}{}
|
||||
sanAddr := strings.Replace(addr.String(), ":", "-", 1)
|
||||
sanAddr := normalizeName(addr.String())
|
||||
participantsVals = append(participantsVals, sanAddr)
|
||||
fmt.Fprintf(
|
||||
f, " participant %s as Nebula: %s<br/>UDP: %s\n",
|
||||
sanAddr, e.packet.from.GetVpnIp(), sanAddr,
|
||||
sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
|
||||
)
|
||||
}
|
||||
|
||||
@ -250,9 +253,9 @@ func (r *R) renderFlow() {
|
||||
|
||||
fmt.Fprintf(f,
|
||||
" %s%s%s: %s(%s), index %v, counter: %v\n",
|
||||
strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
|
||||
normalizeName(p.from.GetUDPAddr().String()),
|
||||
line,
|
||||
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
|
||||
normalizeName(p.to.GetUDPAddr().String()),
|
||||
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
|
||||
)
|
||||
}
|
||||
@ -267,6 +270,11 @@ func (r *R) renderFlow() {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeName(s string) string {
|
||||
rx := regexp.MustCompile("[\\[\\]\\:]")
|
||||
return rx.ReplaceAllLiteralString(s, "_")
|
||||
}
|
||||
|
||||
// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
|
||||
// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
|
||||
// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
|
||||
@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
|
||||
func (r *R) renderHostmaps(title string) {
|
||||
c := maps.Values(r.controls)
|
||||
sort.SliceStable(c, func(i, j int) bool {
|
||||
return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
|
||||
return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
|
||||
})
|
||||
|
||||
s := renderHostmaps(c...)
|
||||
@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
|
||||
// Nope, lets push the sender along
|
||||
case p := <-udpTx:
|
||||
r.Lock()
|
||||
c := r.getControl(sender.GetUDPAddr(), p.To, p)
|
||||
a := sender.GetUDPAddr()
|
||||
c := r.getControl(a, p.To, p)
|
||||
if c == nil {
|
||||
r.Unlock()
|
||||
panic("No control for udp tx")
|
||||
panic("No control for udp tx " + a.String())
|
||||
}
|
||||
fp := r.unlockedInjectFlow(sender, c, p, false)
|
||||
c.InjectUDPPacket(p)
|
||||
@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
|
||||
} else {
|
||||
// we are a udp tx, route and continue
|
||||
p := rx.Interface().(*udp.Packet)
|
||||
c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
|
||||
a := cm[x].GetUDPAddr()
|
||||
c := r.getControl(a, p.To, p)
|
||||
if c == nil {
|
||||
r.Unlock()
|
||||
panic("No control for udp tx")
|
||||
panic(fmt.Sprintf("No control for udp tx %s", p.To))
|
||||
}
|
||||
fp := r.unlockedInjectFlow(cm[x], c, p, false)
|
||||
c.InjectUDPPacket(p)
|
||||
@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
|
||||
}
|
||||
|
||||
func (r *R) formatUdpPacket(p *packet) string {
|
||||
packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||
if v4 == nil {
|
||||
panic("not an ipv4 packet")
|
||||
var packet gopacket.Packet
|
||||
var srcAddr netip.Addr
|
||||
|
||||
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
|
||||
if packet.ErrorLayer() == nil {
|
||||
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
||||
if v6 == nil {
|
||||
panic("not an ipv6 packet")
|
||||
}
|
||||
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
|
||||
} else {
|
||||
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||
v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||
if v6 == nil {
|
||||
panic("not an ipv6 packet")
|
||||
}
|
||||
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
|
||||
}
|
||||
|
||||
from := "unknown"
|
||||
srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
|
||||
if c, ok := r.vpnControls[srcAddr]; ok {
|
||||
from = c.GetUDPAddr().String()
|
||||
}
|
||||
|
||||
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
|
||||
if udp == nil {
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
|
||||
if udpLayer == nil {
|
||||
panic("not a udp packet")
|
||||
}
|
||||
|
||||
data := packet.ApplicationLayer()
|
||||
return fmt.Sprintf(
|
||||
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
|
||||
strings.Replace(from, ":", "-", 1),
|
||||
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
|
||||
udp.SrcPort,
|
||||
udp.DstPort,
|
||||
normalizeName(from),
|
||||
normalizeName(p.to.GetUDPAddr().String()),
|
||||
udpLayer.SrcPort,
|
||||
udpLayer.DstPort,
|
||||
string(data.Payload()),
|
||||
)
|
||||
}
|
||||
|
||||
@ -13,6 +13,12 @@ pki:
|
||||
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
||||
#disconnect_invalid: true
|
||||
|
||||
# default_version controls which certificate version is used in handshakes.
|
||||
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
|
||||
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
|
||||
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
||||
# default_version: 1
|
||||
|
||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
||||
# The syntax is:
|
||||
@ -120,8 +126,8 @@ lighthouse:
|
||||
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
|
||||
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
|
||||
listen:
|
||||
# To listen on both any ipv4 and ipv6 use "::"
|
||||
host: 0.0.0.0
|
||||
# To listen on only ipv4, use "0.0.0.0"
|
||||
host: "::"
|
||||
port: 4242
|
||||
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
||||
# default is 64, does not support reload
|
||||
@ -138,6 +144,11 @@ listen:
|
||||
# valid values: always, never, private
|
||||
# This setting is reloadable.
|
||||
#send_recv_error: always
|
||||
# The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier.
|
||||
# This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes,
|
||||
# allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set.
|
||||
# This setting is reloadable.
|
||||
#so_mark: 0
|
||||
|
||||
# Routines is the number of thread pairs to run that consume from the tun and UDP queues.
|
||||
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1
|
||||
@ -244,7 +255,6 @@ tun:
|
||||
# in nebula configuration files. Default false, not reloadable.
|
||||
#use_system_route_table: false
|
||||
|
||||
# TODO
|
||||
# Configure logging level
|
||||
logging:
|
||||
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
|
||||
@ -336,10 +346,12 @@ firewall:
|
||||
# host: `any` or a literal hostname, ie `test-host`
|
||||
# group: `any` or a literal group name, ie `default-group`
|
||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||
# cidr: a remote CIDR, `0.0.0.0/0` is any.
|
||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
|
||||
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
|
||||
# if `default_local_cidr_any` is false, otherwise its `any`.
|
||||
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
|
||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
|
||||
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
|
||||
# Otherwise the default is any vpn network assigned to via the certificate.
|
||||
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
|
||||
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
|
||||
# ca_name: An issuing CA name
|
||||
# ca_sha: An issuing CA shasum
|
||||
|
||||
|
||||
190
firewall.go
190
firewall.go
@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
type FirewallInterface interface {
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
@ -51,10 +51,13 @@ type Firewall struct {
|
||||
UDPTimeout time.Duration //linux: 180s max
|
||||
DefaultTimeout time.Duration //linux: 600s
|
||||
|
||||
// Used to ensure we don't emit local packets for ips we don't own
|
||||
localIps *bart.Table[struct{}]
|
||||
assignedCIDR netip.Prefix
|
||||
hasSubnets bool
|
||||
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
||||
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
||||
routableNetworks *bart.Table[struct{}]
|
||||
|
||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||
assignedNetworks []netip.Prefix
|
||||
hasUnsafeNetworks bool
|
||||
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
@ -67,9 +70,9 @@ type Firewall struct {
|
||||
}
|
||||
|
||||
type firewallMetrics struct {
|
||||
droppedLocalIP metrics.Counter
|
||||
droppedRemoteIP metrics.Counter
|
||||
droppedNoRule metrics.Counter
|
||||
droppedLocalAddr metrics.Counter
|
||||
droppedRemoteAddr metrics.Counter
|
||||
droppedNoRule metrics.Counter
|
||||
}
|
||||
|
||||
type FirewallConntrack struct {
|
||||
@ -126,88 +129,87 @@ type firewallLocalCIDR struct {
|
||||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||
// The certificate provided should be the highest version loaded in memory.
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
var min, max time.Duration
|
||||
var tmin, tmax time.Duration
|
||||
|
||||
if tcpTimeout < UDPTimeout {
|
||||
min = tcpTimeout
|
||||
max = UDPTimeout
|
||||
tmin = tcpTimeout
|
||||
tmax = UDPTimeout
|
||||
} else {
|
||||
min = UDPTimeout
|
||||
max = tcpTimeout
|
||||
tmin = UDPTimeout
|
||||
tmax = tcpTimeout
|
||||
}
|
||||
|
||||
if defaultTimeout < min {
|
||||
min = defaultTimeout
|
||||
} else if defaultTimeout > max {
|
||||
max = defaultTimeout
|
||||
if defaultTimeout < tmin {
|
||||
tmin = defaultTimeout
|
||||
} else if defaultTimeout > tmax {
|
||||
tmax = defaultTimeout
|
||||
}
|
||||
|
||||
localIps := new(bart.Table[struct{}])
|
||||
var assignedCIDR netip.Prefix
|
||||
var assignedSet bool
|
||||
for _, ip := range c.Details.Ips {
|
||||
//TODO: IPV6-WORK the unmap is a bit unfortunate
|
||||
nip, _ := netip.AddrFromSlice(ip.IP)
|
||||
nip = nip.Unmap()
|
||||
nprefix := netip.PrefixFrom(nip, nip.BitLen())
|
||||
localIps.Insert(nprefix, struct{}{})
|
||||
|
||||
if !assignedSet {
|
||||
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
|
||||
assignedCIDR = nprefix
|
||||
assignedSet = true
|
||||
}
|
||||
routableNetworks := new(bart.Table[struct{}])
|
||||
var assignedNetworks []netip.Prefix
|
||||
for _, network := range c.Networks() {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
routableNetworks.Insert(nprefix, struct{}{})
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
}
|
||||
|
||||
for _, n := range c.Details.Subnets {
|
||||
nip, _ := netip.AddrFromSlice(n.IP)
|
||||
ones, _ := n.Mask.Size()
|
||||
nip = nip.Unmap()
|
||||
localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
|
||||
hasUnsafeNetworks := false
|
||||
for _, n := range c.UnsafeNetworks() {
|
||||
routableNetworks.Insert(n, struct{}{})
|
||||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
Conntrack: &FirewallConntrack{
|
||||
Conns: make(map[firewall.Packet]*conn),
|
||||
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
|
||||
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
localIps: localIps,
|
||||
assignedCIDR: assignedCIDR,
|
||||
hasSubnets: len(c.Details.Subnets) > 0,
|
||||
l: l,
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
||||
l: l,
|
||||
|
||||
incomingMetrics: firewallMetrics{
|
||||
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
|
||||
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
|
||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
|
||||
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
|
||||
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
|
||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
|
||||
},
|
||||
outgoingMetrics: firewallMetrics{
|
||||
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
|
||||
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
|
||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
|
||||
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
|
||||
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
|
||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
|
||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||
certificate := cs.getCertificate(cert.Version2)
|
||||
if certificate == nil {
|
||||
certificate = cs.getCertificate(cert.Version1)
|
||||
}
|
||||
|
||||
if certificate == nil {
|
||||
panic("No certificate available to reconfigure the firewall")
|
||||
}
|
||||
|
||||
fw := NewFirewall(
|
||||
l,
|
||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||
nc,
|
||||
certificate,
|
||||
//TODO: max_connections
|
||||
)
|
||||
|
||||
//TODO: Flip to false after v1.9 release
|
||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
|
||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
|
||||
|
||||
inboundAction := c.GetString("firewall.inbound_action", "drop")
|
||||
switch inboundAction {
|
||||
@ -287,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
fp = ft.TCP
|
||||
case firewall.ProtoUDP:
|
||||
fp = ft.UDP
|
||||
case firewall.ProtoICMP:
|
||||
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||
fp = ft.ICMP
|
||||
case firewall.ProtoAny:
|
||||
fp = ft.AnyProto
|
||||
@ -421,33 +423,31 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(fp, h, caPool, localCache) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate
|
||||
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
||||
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
|
||||
_, ok := remoteCidr.Lookup(fp.RemoteIP)
|
||||
if h.networks != nil {
|
||||
_, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
// Simple case: Certificate has one IP and no subnets
|
||||
if fp.RemoteIP != h.vpnIp {
|
||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
|
||||
_, ok := f.localIps.Lookup(fp.LocalIP)
|
||||
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedLocalIP.Inc(1)
|
||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
@ -492,7 +492,7 @@ func (f *Firewall) EmitStats() {
|
||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return true
|
||||
@ -619,7 +619,7 @@ func (f *Firewall) evict(p firewall.Packet) {
|
||||
delete(conntrack.Conns, p)
|
||||
}
|
||||
|
||||
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
|
||||
if ft.AnyProto.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
@ -633,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
|
||||
if ft.UDP.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
case firewall.ProtoICMP:
|
||||
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||
if ft.ICMP.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
@ -663,7 +663,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
|
||||
// We don't have any allowed ports, bail
|
||||
if fp == nil {
|
||||
return false
|
||||
@ -726,7 +726,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
|
||||
if fc == nil {
|
||||
return false
|
||||
}
|
||||
@ -735,18 +735,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
|
||||
return true
|
||||
}
|
||||
|
||||
if t, ok := fc.CAShas[c.Details.Issuer]; ok {
|
||||
if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok {
|
||||
if t.match(p, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
s, err := caPool.GetCAForCert(c)
|
||||
s, err := caPool.GetCAForCert(c.Certificate)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return fc.CANames[s.Details.Name].match(p, c)
|
||||
return fc.CANames[s.Certificate.Name()].match(p, c)
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
|
||||
@ -826,7 +826,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
|
||||
return false
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
||||
func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool {
|
||||
if fr == nil {
|
||||
return false
|
||||
}
|
||||
@ -841,7 +841,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||
found := false
|
||||
|
||||
for _, g := range sg.Groups {
|
||||
if _, ok := c.Details.InvertedGroups[g]; !ok {
|
||||
if _, ok := c.InvertedGroups[g]; !ok {
|
||||
found = false
|
||||
break
|
||||
}
|
||||
@ -855,42 +855,44 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||
}
|
||||
|
||||
if fr.Hosts != nil {
|
||||
if flc, ok := fr.Hosts[c.Details.Name]; ok {
|
||||
if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
|
||||
if flc.match(p, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
matched := false
|
||||
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
|
||||
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
|
||||
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
|
||||
matched = true
|
||||
return false
|
||||
for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
|
||||
if v.match(p, c) {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
})
|
||||
return matched
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
||||
if !localIp.IsValid() {
|
||||
if !f.hasSubnets || f.defaultLocalCIDRAny {
|
||||
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
||||
flc.Any = true
|
||||
return nil
|
||||
}
|
||||
|
||||
localIp = f.assignedCIDR
|
||||
for _, network := range f.assignedNetworks {
|
||||
flc.LocalCIDR.Insert(network, struct{}{})
|
||||
}
|
||||
return nil
|
||||
|
||||
} else if localIp.Bits() == 0 {
|
||||
flc.Any = true
|
||||
return nil
|
||||
}
|
||||
|
||||
flc.LocalCIDR.Insert(localIp, struct{}{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
||||
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool {
|
||||
if flc == nil {
|
||||
return false
|
||||
}
|
||||
@ -899,7 +901,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
|
||||
return true
|
||||
}
|
||||
|
||||
_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
|
||||
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
|
||||
return ok
|
||||
}
|
||||
|
||||
|
||||
@ -9,18 +9,19 @@ import (
|
||||
type m map[string]interface{}
|
||||
|
||||
const (
|
||||
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
||||
ProtoTCP = 6
|
||||
ProtoUDP = 17
|
||||
ProtoICMP = 1
|
||||
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
||||
ProtoTCP = 6
|
||||
ProtoUDP = 17
|
||||
ProtoICMP = 1
|
||||
ProtoICMPv6 = 58
|
||||
|
||||
PortAny = 0 // Special value for matching `port: any`
|
||||
PortFragment = -1 // Special value for matching `port: fragment`
|
||||
)
|
||||
|
||||
type Packet struct {
|
||||
LocalIP netip.Addr
|
||||
RemoteIP netip.Addr
|
||||
LocalAddr netip.Addr
|
||||
RemoteAddr netip.Addr
|
||||
LocalPort uint16
|
||||
RemotePort uint16
|
||||
Protocol uint8
|
||||
@ -29,8 +30,8 @@ type Packet struct {
|
||||
|
||||
func (fp *Packet) Copy() *Packet {
|
||||
return &Packet{
|
||||
LocalIP: fp.LocalIP,
|
||||
RemoteIP: fp.RemoteIP,
|
||||
LocalAddr: fp.LocalAddr,
|
||||
RemoteAddr: fp.RemoteAddr,
|
||||
LocalPort: fp.LocalPort,
|
||||
RemotePort: fp.RemotePort,
|
||||
Protocol: fp.Protocol,
|
||||
@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
|
||||
proto = fmt.Sprintf("unknown %v", fp.Protocol)
|
||||
}
|
||||
return json.Marshal(m{
|
||||
"LocalIP": fp.LocalIP.String(),
|
||||
"RemoteIP": fp.RemoteIP.String(),
|
||||
"LocalAddr": fp.LocalAddr.String(),
|
||||
"RemoteAddr": fp.RemoteAddr.String(),
|
||||
"LocalPort": fp.LocalPort,
|
||||
"RemotePort": fp.RemotePort,
|
||||
"Protocol": proto,
|
||||
|
||||
475
firewall_test.go
475
firewall_test.go
@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@ -14,11 +13,12 @@ import (
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewFirewall(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := &cert.NebulaCertificate{}
|
||||
c := &dummyCert{}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
conntrack := fw.Conntrack
|
||||
assert.NotNil(t, conntrack)
|
||||
@ -60,67 +60,67 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
c := &cert.NebulaCertificate{}
|
||||
c := &dummyCert{}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.NotNil(t, fw.InRules)
|
||||
assert.NotNil(t, fw.OutRules)
|
||||
|
||||
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
// An empty rule is any
|
||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop(t *testing.T) {
|
||||
@ -129,79 +129,74 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
c := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host1",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Groups: []string{"default-group"},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
Issuer: "signer-shasum",
|
||||
},
|
||||
c := dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
peerCert: &cert.CachedCertificate{
|
||||
Certificate: &c,
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
},
|
||||
vpnIp: netip.MustParseAddr("1.2.3.4"),
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteIP
|
||||
p.RemoteIP = netip.MustParseAddr("1.2.3.10")
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteIP = oldRemote
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
@ -217,7 +212,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
|
||||
b.Run("fail on proto", func(b *testing.B) {
|
||||
// This benchmark is showing us the cost of failing to match the protocol
|
||||
c := &cert.NebulaCertificate{}
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
|
||||
}
|
||||
@ -225,28 +222,31 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
|
||||
b.Run("pass proto, fail on port", func(b *testing.B) {
|
||||
// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
|
||||
c := &cert.NebulaCertificate{}
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
|
||||
c := &cert.NebulaCertificate{}
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{},
|
||||
}
|
||||
ip := netip.MustParsePrefix("9.254.254.254/32")
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "nope",
|
||||
Ips: []*net.IPNet{ip},
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
@ -254,25 +254,24 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "nope",
|
||||
Ips: []*net.IPNet{ip},
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on group on any local cidr", func(b *testing.B) {
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
Name: "nope",
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
@ -280,82 +279,28 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("pass on group on specific local cidr", func(b *testing.B) {
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
Name: "nope",
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on name", func(b *testing.B) {
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "good-host",
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "good-host",
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
||||
}
|
||||
})
|
||||
//
|
||||
//b.Run("pass on ip", func(b *testing.B) {
|
||||
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
// c := &cert.NebulaCertificate{
|
||||
// Details: cert.NebulaCertificateDetails{
|
||||
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
// Name: "good-host",
|
||||
// },
|
||||
// }
|
||||
// for n := 0; n < b.N; n++ {
|
||||
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
|
||||
// }
|
||||
//})
|
||||
//
|
||||
//b.Run("pass on local ip", func(b *testing.B) {
|
||||
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
// c := &cert.NebulaCertificate{
|
||||
// Details: cert.NebulaCertificateDetails{
|
||||
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
// Name: "good-host",
|
||||
// },
|
||||
// }
|
||||
// for n := 0; n < b.N; n++ {
|
||||
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
|
||||
// }
|
||||
//})
|
||||
//
|
||||
//_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
|
||||
//
|
||||
//b.Run("pass on ip with any port", func(b *testing.B) {
|
||||
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
// c := &cert.NebulaCertificate{
|
||||
// Details: cert.NebulaCertificateDetails{
|
||||
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
// Name: "good-host",
|
||||
// },
|
||||
// }
|
||||
// for n := 0; n < b.N; n++ {
|
||||
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
|
||||
// }
|
||||
//})
|
||||
//
|
||||
//b.Run("pass on local ip with any port", func(b *testing.B) {
|
||||
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
// c := &cert.NebulaCertificate{
|
||||
// Details: cert.NebulaCertificateDetails{
|
||||
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
// Name: "good-host",
|
||||
// },
|
||||
// }
|
||||
// for n := 0; n < b.N; n++ {
|
||||
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
|
||||
// }
|
||||
//})
|
||||
}
|
||||
|
||||
func TestFirewall_Drop2(t *testing.T) {
|
||||
@ -364,57 +309,55 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
network := netip.MustParsePrefix("1.2.3.4/24")
|
||||
|
||||
c := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host1",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{network},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
c1 := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host1",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{network},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
|
||||
}
|
||||
h1 := HostInfo{
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c1,
|
||||
},
|
||||
}
|
||||
h1.CreateRemoteCIDR(&c1)
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// h1/c1 lacks the proper groups
|
||||
assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
// c has the proper groups
|
||||
resetConntrack(fw)
|
||||
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3(t *testing.T) {
|
||||
@ -423,84 +366,85 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalPort: 1,
|
||||
RemotePort: 1,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
c := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host-owner",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
network := netip.MustParsePrefix("1.2.3.4/24")
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host-owner",
|
||||
networks: []netip.Prefix{network},
|
||||
},
|
||||
}
|
||||
|
||||
c1 := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host1",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Issuer: "signer-sha-bad",
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{network},
|
||||
issuer: "signer-sha-bad",
|
||||
},
|
||||
}
|
||||
h1 := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c1,
|
||||
},
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h1.CreateRemoteCIDR(&c1)
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
c2 := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host2",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Issuer: "signer-sha",
|
||||
c2 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host2",
|
||||
networks: []netip.Prefix{network},
|
||||
issuer: "signer-sha",
|
||||
},
|
||||
}
|
||||
h2 := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c2,
|
||||
},
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h2.CreateRemoteCIDR(&c2)
|
||||
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||
|
||||
c3 := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host3",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Issuer: "signer-sha-bad",
|
||||
c3 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host3",
|
||||
networks: []netip.Prefix{network},
|
||||
issuer: "signer-sha-bad",
|
||||
},
|
||||
}
|
||||
h3 := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c3,
|
||||
},
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h3.CreateRemoteCIDR(&c3)
|
||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// c1 should pass because host match
|
||||
assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
// c2 should pass because ca sha match
|
||||
resetConntrack(fw)
|
||||
assert.NoError(t, fw.Drop(p, true, &h2, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
|
||||
// c3 should fail because no match
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// Test a remote address match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
@ -509,60 +453,56 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
network := netip.MustParsePrefix("1.2.3.4/24")
|
||||
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
c := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host1",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Groups: []string{"default-group"},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
Issuer: "signer-shasum",
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{network},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Allow outbound because conntrack and new rules allow port 10
|
||||
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
@ -641,104 +581,105 @@ func BenchmarkLookup(b *testing.B) {
|
||||
ml(m, a)
|
||||
}
|
||||
})
|
||||
|
||||
//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
|
||||
}
|
||||
|
||||
func Test_parsePort(t *testing.T) {
|
||||
_, _, err := parsePort("")
|
||||
assert.EqualError(t, err, "was not a number; ``")
|
||||
require.EqualError(t, err, "was not a number; ``")
|
||||
|
||||
_, _, err = parsePort(" ")
|
||||
assert.EqualError(t, err, "was not a number; ` `")
|
||||
require.EqualError(t, err, "was not a number; ` `")
|
||||
|
||||
_, _, err = parsePort("-")
|
||||
assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
|
||||
require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
|
||||
|
||||
_, _, err = parsePort(" - ")
|
||||
assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
|
||||
require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
|
||||
|
||||
_, _, err = parsePort("a-b")
|
||||
assert.EqualError(t, err, "beginning range was not a number; `a`")
|
||||
require.EqualError(t, err, "beginning range was not a number; `a`")
|
||||
|
||||
_, _, err = parsePort("1-b")
|
||||
assert.EqualError(t, err, "ending range was not a number; `b`")
|
||||
require.EqualError(t, err, "ending range was not a number; `b`")
|
||||
|
||||
s, e, err := parsePort(" 1 - 2 ")
|
||||
assert.Equal(t, int32(1), s)
|
||||
assert.Equal(t, int32(2), e)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, e, err = parsePort("0-1")
|
||||
assert.Equal(t, int32(0), s)
|
||||
assert.Equal(t, int32(0), e)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, e, err = parsePort("9919")
|
||||
assert.Equal(t, int32(9919), s)
|
||||
assert.Equal(t, int32(9919), e)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, e, err = parsePort("any")
|
||||
assert.Equal(t, int32(0), s)
|
||||
assert.Equal(t, int32(0), e)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNewFirewallFromConfig(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
// Test a bad rule definition
|
||||
c := &cert.NebulaCertificate{}
|
||||
c := &dummyCert{}
|
||||
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conf := config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
||||
_, err := NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||
|
||||
// Test both port and code
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||
|
||||
// Test missing host, group, cidr, ca_name and ca_sha
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||
|
||||
// Test code/port error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||
|
||||
// Test proto error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||
|
||||
// Test cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test local_cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test both group and groups
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||
}
|
||||
|
||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
@ -747,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
conf := config.NewC(l)
|
||||
mf := &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding udp rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding icmp rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding any rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding rule with cidr
|
||||
@ -776,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding rule with local_cidr
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_name
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
|
||||
|
||||
// Test single group
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test single groups
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test multiple AND groups
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test Add error
|
||||
@ -826,7 +767,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
mf = &mockFirewall{}
|
||||
mf.nextCallReturn = errors.New("test error")
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
|
||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
|
||||
}
|
||||
|
||||
func TestFirewall_convertRule(t *testing.T) {
|
||||
@ -841,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
|
||||
r, err := convertRule(l, c, "test", 1)
|
||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group1", r.Group)
|
||||
|
||||
// Ensure group array of > 1 is errord
|
||||
@ -852,7 +793,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
|
||||
r, err = convertRule(l, c, "test", 1)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
|
||||
require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
|
||||
|
||||
// Make sure a well formed group is alright
|
||||
ob.Reset()
|
||||
@ -861,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
}
|
||||
|
||||
r, err = convertRule(l, c, "test", 1)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group1", r.Group)
|
||||
}
|
||||
|
||||
|
||||
38
go.mod
38
go.mod
@ -1,8 +1,8 @@
|
||||
module github.com/slackhq/nebula
|
||||
|
||||
go 1.22.0
|
||||
go 1.23.6
|
||||
|
||||
toolchain go1.22.2
|
||||
toolchain go1.23.7
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.1
|
||||
@ -10,45 +10,45 @@ require (
|
||||
github.com/armon/go-radix v1.0.0
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/gaissmai/bart v0.11.1
|
||||
github.com/gaissmai/bart v0.18.1
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/kardianos/service v1.2.2
|
||||
github.com/miekg/dns v1.1.61
|
||||
github.com/miekg/dns v1.1.63
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/prometheus/client_golang v1.21.1
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/crypto v0.26.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/net v0.28.0
|
||||
golang.org/x/sync v0.8.0
|
||||
golang.org/x/sys v0.24.0
|
||||
golang.org/x/term v0.23.0
|
||||
golang.org/x/net v0.37.0
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.org/x/term v0.30.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/protobuf v1.34.2
|
||||
google.golang.org/protobuf v1.36.5
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bits-and-blooms/bitset v1.13.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/klauspost/compress v1.17.11 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.5.0 // indirect
|
||||
github.com/prometheus/common v0.48.0 // indirect
|
||||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/mod v0.18.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
|
||||
75
go.sum
75
go.sum
@ -14,11 +14,9 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
|
||||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
|
||||
github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
|
||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@ -26,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
|
||||
github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
|
||||
github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ=
|
||||
github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
@ -70,6 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
|
||||
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||
@ -80,15 +80,19 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs=
|
||||
github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ=
|
||||
github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY=
|
||||
github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs=
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg=
|
||||
@ -102,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
|
||||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
||||
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
||||
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
|
||||
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
|
||||
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
|
||||
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
|
||||
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
|
||||
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
|
||||
github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
|
||||
github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
|
||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
||||
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
|
||||
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
@ -131,8 +135,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@ -141,11 +143,10 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@ -155,8 +156,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
|
||||
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
@ -175,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
|
||||
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -184,30 +185,30 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
|
||||
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
|
||||
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@ -238,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
452
handshake_ix.go
452
handshake_ix.go
@ -2,10 +2,12 @@ package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
@ -16,30 +18,60 @@ import (
|
||||
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
err := f.handshakeManager.allocateIndex(hh)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
return false
|
||||
}
|
||||
|
||||
certState := f.pki.GetCertState()
|
||||
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.defaultVersion
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
crt := cs.getCertificate(v)
|
||||
if crt == nil {
|
||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
return false
|
||||
}
|
||||
|
||||
crtHs := cs.getHandshakeBytes(v)
|
||||
if crtHs == nil {
|
||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Failed to create connection state")
|
||||
return false
|
||||
}
|
||||
hh.hostinfo.ConnectionState = ci
|
||||
|
||||
hsProto := &NebulaHandshakeDetails{
|
||||
InitiatorIndex: hh.hostinfo.localIndexId,
|
||||
Time: uint64(time.Now().UnixNano()),
|
||||
Cert: certState.RawCertificateNoKey,
|
||||
}
|
||||
|
||||
hsBytes := []byte{}
|
||||
|
||||
hs := &NebulaHandshake{
|
||||
Details: hsProto,
|
||||
Details: &NebulaHandshakeDetails{
|
||||
InitiatorIndex: hh.hostinfo.localIndexId,
|
||||
Time: uint64(time.Now().UnixNano()),
|
||||
Cert: crtHs,
|
||||
CertVersion: uint32(v),
|
||||
},
|
||||
}
|
||||
hsBytes, err = hs.Marshal()
|
||||
|
||||
hsBytes, err := hs.Marshal()
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
return false
|
||||
}
|
||||
@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
|
||||
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
return false
|
||||
}
|
||||
@ -63,79 +95,141 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
|
||||
certState := f.pki.GetCertState()
|
||||
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
|
||||
cs := f.pki.getCertState()
|
||||
crt := cs.GetDefaultCertificate()
|
||||
if crt == nil {
|
||||
f.l.WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", cs.defaultVersion).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed to create connection state")
|
||||
return
|
||||
}
|
||||
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed to call noise.ReadMessage")
|
||||
return
|
||||
}
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
err = hs.Unmarshal(msg)
|
||||
/*
|
||||
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
||||
*/
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed unmarshal handshake message")
|
||||
return
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
|
||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||
if err != nil {
|
||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake did not contain a certificate")
|
||||
return
|
||||
}
|
||||
|
||||
if f.l.Level > logrus.DebugLevel {
|
||||
e = e.WithField("cert", remoteCert)
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, err := rc.Fingerprint()
|
||||
if err != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("certVpnNetworks", rc.Networks()).
|
||||
WithField("certFingerprint", fp)
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
e = e.WithField("cert", rc)
|
||||
}
|
||||
|
||||
e.Info("Invalid certificate from host")
|
||||
return
|
||||
}
|
||||
|
||||
vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
|
||||
if !ok {
|
||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
||||
|
||||
if f.l.Level > logrus.DebugLevel {
|
||||
e = e.WithField("cert", remoteCert)
|
||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if rc == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Unable to handshake with host due to missing certificate version")
|
||||
return
|
||||
}
|
||||
|
||||
e.Info("Invalid vpn ip from host")
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = rc
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("cert", remoteCert).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("No networks in certificate")
|
||||
return
|
||||
}
|
||||
|
||||
vpnIp = vpnIp.Unmap()
|
||||
certName := remoteCert.Details.Name
|
||||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
issuer := remoteCert.Details.Issuer
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
certName := remoteCert.Certificate.Name()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
|
||||
if vpnIp == f.myVpnNet.Addr() {
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
for _, network := range remoteCert.Certificate.Networks() {
|
||||
vpnAddr := network.Addr()
|
||||
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
|
||||
if found {
|
||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
return
|
||||
}
|
||||
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
// addr can be invalid when the tunnel is being relayed.
|
||||
// We only want to apply the remote allow list for direct tunnels here
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
myIndex, err := generateIndex(f.l)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -147,17 +241,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
ConnectionState: ci,
|
||||
localIndexId: myIndex,
|
||||
remoteIndexId: hs.Details.InitiatorIndex,
|
||||
vpnIp: vpnIp,
|
||||
vpnAddrs: vpnAddrs,
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
lastHandshakeTime: hs.Details.Time,
|
||||
relayState: RelayState{
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByIp: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}
|
||||
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -166,13 +260,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
Info("Handshake message received")
|
||||
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = certState.RawCertificateNoKey
|
||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||
if hs.Details.Cert == nil {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("certVersion", ci.myCert.Version()).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return
|
||||
}
|
||||
|
||||
hs.Details.CertVersion = uint32(ci.myCert.Version())
|
||||
// Update the time in case their clock is way off from ours
|
||||
hs.Details.Time = uint64(time.Now().UnixNano())
|
||||
|
||||
hsBytes, err := hs.Marshal()
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -183,14 +290,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
return
|
||||
} else if dKey == nil || eKey == nil {
|
||||
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -214,9 +321,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
ci.dKey = NewNebulaCipherState(dKey)
|
||||
ci.eKey = NewNebulaCipherState(eKey)
|
||||
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||
hostinfo.SetRemote(addr)
|
||||
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
|
||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||
if err != nil {
|
||||
@ -226,7 +333,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
|
||||
// Send a test packet to ensure the other side has also switched to
|
||||
// the preferred remote
|
||||
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
}
|
||||
|
||||
msg = existing.HandshakePacket[2]
|
||||
@ -234,11 +341,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if addr.IsValid() {
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
} else {
|
||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
@ -248,16 +355,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
f.l.Error("Handshake send failed: both addr and via are nil.")
|
||||
return
|
||||
}
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
|
||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
return
|
||||
}
|
||||
case ErrExistingHostInfo:
|
||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
||||
@ -268,23 +375,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
Info("Handshake too old")
|
||||
|
||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
return
|
||||
case ErrLocalIndexCollision:
|
||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
|
||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
|
||||
Error("Failed to add HostInfo due to localIndex collision")
|
||||
return
|
||||
default:
|
||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||
// And we forget to update it here
|
||||
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -300,7 +407,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if addr.IsValid() {
|
||||
err = f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -308,7 +415,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithError(err).Error("Failed to send handshake")
|
||||
} else {
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -321,9 +428,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
f.l.Error("Handshake send failed: both addr and via are nil.")
|
||||
return
|
||||
}
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
|
||||
// it's correctly marked as working.
|
||||
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -350,8 +460,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
|
||||
hostinfo := hh.hostinfo
|
||||
if addr.IsValid() {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
|
||||
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -359,7 +470,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
ci := hostinfo.ConnectionState
|
||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||
Error("Failed to call noise.ReadMessage")
|
||||
|
||||
@ -368,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
// near future
|
||||
return false
|
||||
} else if dKey == nil || eKey == nil {
|
||||
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Error("Noise did not arrive at a key")
|
||||
|
||||
@ -380,95 +491,56 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
hs := &NebulaHandshake{}
|
||||
err = hs.Unmarshal(msg)
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
|
||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||
return true
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
|
||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||
if err != nil {
|
||||
e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
||||
|
||||
if f.l.Level > logrus.DebugLevel {
|
||||
e = e.WithField("cert", remoteCert)
|
||||
}
|
||||
|
||||
e.Error("Invalid certificate from host")
|
||||
|
||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||
return true
|
||||
}
|
||||
|
||||
vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
|
||||
if !ok {
|
||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
||||
|
||||
if f.l.Level > logrus.DebugLevel {
|
||||
e = e.WithField("cert", remoteCert)
|
||||
}
|
||||
|
||||
e.Info("Invalid vpn ip from host")
|
||||
return true
|
||||
}
|
||||
|
||||
vpnIp = vpnIp.Unmap()
|
||||
certName := remoteCert.Details.Name
|
||||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
issuer := remoteCert.Details.Issuer
|
||||
|
||||
// Ensure the right host responded
|
||||
if vpnIp != hostinfo.vpnIp {
|
||||
f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
|
||||
WithField("udpAddr", addr).WithField("certName", certName).
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("Incorrect host responded to handshake")
|
||||
|
||||
// Release our old handshake from pending, it should not continue
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
|
||||
// Create a new hostinfo/handshake for the intended vpn ip
|
||||
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
|
||||
//TODO: this doesnt know if its being added or is being used for caching a packet
|
||||
// Block the current used address
|
||||
newHH.hostinfo.remotes = hostinfo.remotes
|
||||
newHH.hostinfo.remotes.BlockRemote(addr)
|
||||
|
||||
// Get the correct remote list for the host we did handshake with
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
|
||||
|
||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
|
||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
||||
Info("Blocked addresses for handshakes")
|
||||
|
||||
// Swap the packet store to benefit the original intended recipient
|
||||
newHH.packetStore = hh.packetStore
|
||||
hh.packetStore = []*cachedPacket{}
|
||||
|
||||
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
|
||||
hostinfo.vpnIp = vpnIp
|
||||
f.sendCloseTunnel(hostinfo)
|
||||
})
|
||||
|
||||
Info("Handshake did not contain a certificate")
|
||||
return true
|
||||
}
|
||||
|
||||
// Mark packet 2 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 2)
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, err := rc.Fingerprint()
|
||||
if err != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("durationNs", duration).
|
||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||
Info("Handshake message received")
|
||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("certFingerprint", fp).
|
||||
WithField("certVpnNetworks", rc.Networks())
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
e = e.WithField("cert", rc)
|
||||
}
|
||||
|
||||
e.Info("Invalid certificate from host")
|
||||
return true
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("cert", remoteCert).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("No networks in certificate")
|
||||
return true
|
||||
}
|
||||
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
certName := remoteCert.Certificate.Name()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
|
||||
hostinfo.remoteIndexId = hs.Details.ResponderIndex
|
||||
hostinfo.lastHandshakeTime = hs.Details.Time
|
||||
@ -482,13 +554,83 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
if addr.IsValid() {
|
||||
hostinfo.SetRemote(addr)
|
||||
} else {
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
}
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
for _, network := range vpnNetworks {
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
vpnAddr := network.Addr()
|
||||
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure the right host responded
|
||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||
WithField("udpAddr", addr).WithField("certName", certName).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("Incorrect host responded to handshake")
|
||||
|
||||
// Release our old handshake from pending, it should not continue
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
|
||||
// Create a new hostinfo/handshake for the intended vpn ip
|
||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||
// Block the current used address
|
||||
newHH.hostinfo.remotes = hostinfo.remotes
|
||||
newHH.hostinfo.remotes.BlockRemote(addr)
|
||||
|
||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
||||
WithField("vpnNetworks", vpnNetworks).
|
||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
||||
Info("Blocked addresses for handshakes")
|
||||
|
||||
// Swap the packet store to benefit the original intended recipient
|
||||
newHH.packetStore = hh.packetStore
|
||||
hh.packetStore = []*cachedPacket{}
|
||||
|
||||
// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
f.sendCloseTunnel(hostinfo)
|
||||
})
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Mark packet 2 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("durationNs", duration).
|
||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||
Info("Handshake message received")
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
|
||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
||||
|
||||
|
||||
@ -7,14 +7,15 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) Run(ctx context.Context) {
|
||||
clockSource := time.NewTicker(c.config.tryInterval)
|
||||
func (hm *HandshakeManager) Run(ctx context.Context) {
|
||||
clockSource := time.NewTicker(hm.config.tryInterval)
|
||||
defer clockSource.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case vpnIP := <-c.trigger:
|
||||
c.handleOutbound(vpnIP, true)
|
||||
case vpnIP := <-hm.trigger:
|
||||
hm.handleOutbound(vpnIP, true)
|
||||
case now := <-clockSource.C:
|
||||
c.NextOutboundHandshakeTimerTick(now)
|
||||
hm.NextOutboundHandshakeTimerTick(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -137,7 +138,7 @@ func (c *HandshakeManager) Run(ctx context.Context) {
|
||||
func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
|
||||
// First remote allow list check before we know the vpnIp
|
||||
if addr.IsValid() {
|
||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
|
||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
|
||||
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return
|
||||
}
|
||||
@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
|
||||
c.OutboundHandshakeTimer.Advance(now)
|
||||
func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
|
||||
hm.OutboundHandshakeTimer.Advance(now)
|
||||
for {
|
||||
vpnIp, has := c.OutboundHandshakeTimer.Purge()
|
||||
vpnIp, has := hm.OutboundHandshakeTimer.Purge()
|
||||
if !has {
|
||||
break
|
||||
}
|
||||
c.handleOutbound(vpnIp, false)
|
||||
hm.handleOutbound(vpnIp, false)
|
||||
}
|
||||
}
|
||||
|
||||
@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
// NB ^ This comment doesn't jive. It's how the thing gets initialized.
|
||||
// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
|
||||
if hostinfo.remotes == nil {
|
||||
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
|
||||
hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
|
||||
}
|
||||
|
||||
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
|
||||
@ -223,7 +224,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
|
||||
hh.lastRemotes = remotes
|
||||
|
||||
// TODO: this will generate a load of queries for hosts with only 1 ip
|
||||
// This will generate a load of queries for hosts with only 1 ip
|
||||
// (such as ones registered to the lighthouse with only a private IP)
|
||||
// So we only do it one time after attempting 5 handshakes already.
|
||||
if len(remotes) <= 1 && hh.counter == 5 {
|
||||
@ -256,7 +257,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message sent")
|
||||
} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
|
||||
} else if hm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
@ -267,59 +268,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
||||
// Send a RelayRequest to all known Relay IP's
|
||||
for _, relay := range hostinfo.remotes.relays {
|
||||
// Don't relay to myself, and don't relay through the host I'm trying to connect to
|
||||
if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
|
||||
// Don't relay to myself
|
||||
if relay == vpnIp {
|
||||
continue
|
||||
}
|
||||
relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
|
||||
|
||||
// Don't relay through the host I'm trying to connect to
|
||||
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
|
||||
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
|
||||
hm.f.Handshake(relay)
|
||||
continue
|
||||
}
|
||||
// Check the relay HostInfo to see if we already established a relay through it
|
||||
if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
|
||||
switch existingRelay.State {
|
||||
case Established:
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
|
||||
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
||||
case Requested:
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
myVpnIpB := hm.f.myVpnNet.Addr().As4()
|
||||
theirVpnIpB := vpnIp.As4()
|
||||
|
||||
// Re-send the CreateRelay request, in case the previous one was lost.
|
||||
m := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: existingRelay.LocalIndex,
|
||||
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
|
||||
}
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).
|
||||
WithError(err).
|
||||
Error("Failed to marshal Control message to create relay")
|
||||
} else {
|
||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.f.myVpnNet.Addr(),
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
||||
"relay": relay}).
|
||||
Info("send CreateRelayRequest")
|
||||
}
|
||||
default:
|
||||
hostinfo.logger(hm.l).
|
||||
WithField("vpnIp", vpnIp).
|
||||
WithField("state", existingRelay.State).
|
||||
WithField("relay", relayHostInfo.vpnIp).
|
||||
Errorf("Relay unexpected state")
|
||||
}
|
||||
} else {
|
||||
// Check the relay HostInfo to see if we already established a relay through
|
||||
existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
|
||||
if !ok {
|
||||
// No relays exist or requested yet.
|
||||
if relayHostInfo.remote.IsValid() {
|
||||
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
||||
@ -327,16 +295,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
myVpnIpB := hm.f.myVpnNet.Addr().As4()
|
||||
theirVpnIpB := vpnIp.As4()
|
||||
|
||||
m := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: idx,
|
||||
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
|
||||
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
|
||||
}
|
||||
|
||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||
case cert.Version1:
|
||||
if !hm.f.myVpnAddrs[0].Is4() {
|
||||
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
if !vpnIp.Is4() {
|
||||
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
b := hm.f.myVpnAddrs[0].As4()
|
||||
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||
b = vpnIp.As4()
|
||||
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||
case cert.Version2:
|
||||
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
||||
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||
default:
|
||||
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).
|
||||
@ -345,13 +332,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
} else {
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.f.myVpnNet.Addr(),
|
||||
"relayFrom": hm.f.myVpnAddrs[0],
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": idx,
|
||||
"relay": relay}).
|
||||
Info("send CreateRelayRequest")
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch existingRelay.State {
|
||||
case Established:
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
|
||||
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
||||
case Disestablished:
|
||||
// Mark this relay as 'requested'
|
||||
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
|
||||
fallthrough
|
||||
case Requested:
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
|
||||
// Re-send the CreateRelay request, in case the previous one was lost.
|
||||
m := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
InitiatorRelayIndex: existingRelay.LocalIndex,
|
||||
}
|
||||
|
||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||
case cert.Version1:
|
||||
if !hm.f.myVpnAddrs[0].Is4() {
|
||||
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
if !vpnIp.Is4() {
|
||||
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
b := hm.f.myVpnAddrs[0].As4()
|
||||
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||
b = vpnIp.As4()
|
||||
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||
case cert.Version2:
|
||||
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
||||
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||
default:
|
||||
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||
continue
|
||||
}
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).
|
||||
WithError(err).
|
||||
Error("Failed to marshal Control message to create relay")
|
||||
} else {
|
||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.f.myVpnAddrs[0],
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
||||
"relay": relay}).
|
||||
Info("send CreateRelayRequest")
|
||||
}
|
||||
case PeerRequested:
|
||||
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
|
||||
fallthrough
|
||||
default:
|
||||
hostinfo.logger(hm.l).
|
||||
WithField("vpnIp", vpnIp).
|
||||
WithField("state", existingRelay.State).
|
||||
WithField("relay", relay).
|
||||
Errorf("Relay unexpected state")
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -381,10 +435,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
|
||||
}
|
||||
|
||||
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
|
||||
func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
|
||||
func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
|
||||
hm.Lock()
|
||||
|
||||
if hh, ok := hm.vpnIps[vpnIp]; ok {
|
||||
if hh, ok := hm.vpnIps[vpnAddr]; ok {
|
||||
// We are already trying to handshake with this vpn ip
|
||||
if cacheCb != nil {
|
||||
cacheCb(hh)
|
||||
@ -394,12 +448,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
|
||||
}
|
||||
|
||||
hostinfo := &HostInfo{
|
||||
vpnIp: vpnIp,
|
||||
vpnAddrs: []netip.Addr{vpnAddr},
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
relayState: RelayState{
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByIp: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}
|
||||
|
||||
@ -407,9 +461,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
|
||||
hostinfo: hostinfo,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
hm.vpnIps[vpnIp] = hh
|
||||
hm.vpnIps[vpnAddr] = hh
|
||||
hm.metricInitiated.Inc(1)
|
||||
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
|
||||
hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
|
||||
|
||||
if cacheCb != nil {
|
||||
cacheCb(hh)
|
||||
@ -417,21 +471,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
|
||||
|
||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||
// We can trigger the handshake right now
|
||||
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
|
||||
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
|
||||
if !doTrigger {
|
||||
// Add any calculated remotes, and trigger early handshake if one found
|
||||
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
|
||||
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
|
||||
}
|
||||
|
||||
if doTrigger {
|
||||
select {
|
||||
case hm.trigger <- vpnIp:
|
||||
case hm.trigger <- vpnAddr:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
hm.Unlock()
|
||||
hm.lightHouse.QueryServer(vpnIp)
|
||||
hm.lightHouse.QueryServer(vpnAddr)
|
||||
return hostinfo
|
||||
}
|
||||
|
||||
@ -452,14 +506,14 @@ var (
|
||||
//
|
||||
// ErrLocalIndexCollision if we already have an entry in the main or pending
|
||||
// hostmap for the hostinfo.localIndexId.
|
||||
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
|
||||
c.mainHostMap.Lock()
|
||||
defer c.mainHostMap.Unlock()
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
|
||||
hm.mainHostMap.Lock()
|
||||
defer hm.mainHostMap.Unlock()
|
||||
hm.Lock()
|
||||
defer hm.Unlock()
|
||||
|
||||
// Check if we already have a tunnel with this vpn ip
|
||||
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
|
||||
existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
|
||||
if found && existingHostInfo != nil {
|
||||
testHostInfo := existingHostInfo
|
||||
for testHostInfo != nil {
|
||||
@ -476,31 +530,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||
return existingHostInfo, ErrExistingHostInfo
|
||||
}
|
||||
|
||||
existingHostInfo.logger(c.l).Info("Taking new handshake")
|
||||
existingHostInfo.logger(hm.l).Info("Taking new handshake")
|
||||
}
|
||||
|
||||
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
|
||||
existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
|
||||
if found {
|
||||
// We have a collision, but for a different hostinfo
|
||||
return existingIndex, ErrLocalIndexCollision
|
||||
}
|
||||
|
||||
existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
|
||||
existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
|
||||
if found && existingPendingIndex.hostinfo != hostinfo {
|
||||
// We have a collision, but for a different hostinfo
|
||||
return existingPendingIndex.hostinfo, ErrLocalIndexCollision
|
||||
}
|
||||
|
||||
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
|
||||
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
|
||||
existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
|
||||
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger(c.l).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
|
||||
hostinfo.logger(hm.l).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
||||
Info("New host shadows existing host remoteIndex")
|
||||
}
|
||||
|
||||
c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
||||
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
||||
return existingHostInfo, nil
|
||||
}
|
||||
|
||||
@ -518,7 +572,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger(hm.l).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
||||
Info("New host shadows existing host remoteIndex")
|
||||
}
|
||||
|
||||
@ -555,31 +609,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
|
||||
return errors.New("failed to generate unique localIndexId")
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.unlockedDeleteHostInfo(hostinfo)
|
||||
func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
|
||||
hm.Lock()
|
||||
defer hm.Unlock()
|
||||
hm.unlockedDeleteHostInfo(hostinfo)
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||
delete(c.vpnIps, hostinfo.vpnIp)
|
||||
if len(c.vpnIps) == 0 {
|
||||
c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
|
||||
func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||
for _, addr := range hostinfo.vpnAddrs {
|
||||
delete(hm.vpnIps, addr)
|
||||
}
|
||||
|
||||
delete(c.indexes, hostinfo.localIndexId)
|
||||
if len(c.vpnIps) == 0 {
|
||||
c.indexes = map[uint32]*HandshakeHostInfo{}
|
||||
if len(hm.vpnIps) == 0 {
|
||||
hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
|
||||
}
|
||||
|
||||
if c.l.Level >= logrus.DebugLevel {
|
||||
c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
|
||||
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
delete(hm.indexes, hostinfo.localIndexId)
|
||||
if len(hm.indexes) == 0 {
|
||||
hm.indexes = map[uint32]*HandshakeHostInfo{}
|
||||
}
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
|
||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
Debug("Pending hostmap hostInfo deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
|
||||
func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
|
||||
hh := hm.queryVpnIp(vpnIp)
|
||||
if hh != nil {
|
||||
return hh.hostinfo
|
||||
@ -608,37 +665,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
|
||||
return hm.indexes[index]
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
|
||||
return c.mainHostMap.GetPreferredRanges()
|
||||
func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
|
||||
return hm.mainHostMap.GetPreferredRanges()
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
|
||||
hm.RLock()
|
||||
defer hm.RUnlock()
|
||||
|
||||
for _, v := range c.vpnIps {
|
||||
for _, v := range hm.vpnIps {
|
||||
f(v.hostinfo)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) ForEachIndex(f controlEach) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
func (hm *HandshakeManager) ForEachIndex(f controlEach) {
|
||||
hm.RLock()
|
||||
defer hm.RUnlock()
|
||||
|
||||
for _, v := range c.indexes {
|
||||
for _, v := range hm.indexes {
|
||||
f(v.hostinfo)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) EmitStats() {
|
||||
c.RLock()
|
||||
hostLen := len(c.vpnIps)
|
||||
indexLen := len(c.indexes)
|
||||
c.RUnlock()
|
||||
func (hm *HandshakeManager) EmitStats() {
|
||||
hm.RLock()
|
||||
hostLen := len(hm.vpnIps)
|
||||
indexLen := len(hm.indexes)
|
||||
hm.RUnlock()
|
||||
|
||||
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
|
||||
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
|
||||
c.mainHostMap.EmitStats()
|
||||
hm.mainHostMap.EmitStats()
|
||||
}
|
||||
|
||||
// Utility functions below
|
||||
|
||||
@ -14,21 +14,20 @@ import (
|
||||
|
||||
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
ip := netip.MustParseAddr("172.1.1.2")
|
||||
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
mainHM := newHostMap(l, vpncidr)
|
||||
mainHM := newHostMap(l)
|
||||
mainHM.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
||||
cs := &CertState{
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
Certificate: &cert.NebulaCertificate{},
|
||||
RawCertificateNoKey: []byte{},
|
||||
defaultVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||
@ -42,10 +41,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
i2 := blah.StartHandshake(ip, nil)
|
||||
assert.Same(t, i, i2)
|
||||
|
||||
i.remotes = NewRemoteList(nil)
|
||||
i.remotes = NewRemoteList([]netip.Addr{}, nil)
|
||||
|
||||
// Adding something to pending should not affect the main hostmap
|
||||
assert.Len(t, mainHM.Hosts, 0)
|
||||
assert.Empty(t, mainHM.Hosts)
|
||||
|
||||
// Confirm they are in the pending index list
|
||||
assert.Contains(t, blah.vpnIps, ip)
|
||||
@ -80,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
|
||||
type mockEncWriter struct {
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
|
||||
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
|
||||
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
|
||||
return
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
|
||||
func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
|
||||
|
||||
func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) GetCertState() *CertState {
|
||||
return &CertState{defaultVersion: cert.Version2}
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type headerTest struct {
|
||||
@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) {
|
||||
|
||||
func TestHeader_MarshalJSON(t *testing.T) {
|
||||
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
|
||||
|
||||
275
hostmap.go
275
hostmap.go
@ -35,6 +35,7 @@ const (
|
||||
Requested = iota
|
||||
PeerRequested
|
||||
Established
|
||||
Disestablished
|
||||
)
|
||||
|
||||
const (
|
||||
@ -48,7 +49,7 @@ type Relay struct {
|
||||
State int
|
||||
LocalIndex uint32
|
||||
RemoteIndex uint32
|
||||
PeerIp netip.Addr
|
||||
PeerAddr netip.Addr
|
||||
}
|
||||
|
||||
type HostMap struct {
|
||||
@ -58,7 +59,6 @@ type HostMap struct {
|
||||
RemoteIndexes map[uint32]*HostInfo
|
||||
Hosts map[netip.Addr]*HostInfo
|
||||
preferredRanges atomic.Pointer[[]netip.Prefix]
|
||||
vpnCIDR netip.Prefix
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
@ -68,9 +68,12 @@ type HostMap struct {
|
||||
type RelayState struct {
|
||||
sync.RWMutex
|
||||
|
||||
relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
|
||||
relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
|
||||
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
|
||||
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
|
||||
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
|
||||
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
|
||||
// the RelayState Lock held)
|
||||
relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
|
||||
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
|
||||
}
|
||||
|
||||
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
|
||||
@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) {
|
||||
delete(rs.relays, ip)
|
||||
}
|
||||
|
||||
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
if r, ok := rs.relayForByAddr[vpnIp]; ok {
|
||||
newRelay := *r
|
||||
newRelay.State = state
|
||||
rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
|
||||
rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
if r, ok := rs.relayForByIdx[idx]; ok {
|
||||
newRelay := *r
|
||||
newRelay.State = state
|
||||
rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
|
||||
rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *RelayState) CopyAllRelayFor() []*Relay {
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
|
||||
return ret
|
||||
}
|
||||
|
||||
func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
|
||||
func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
r, ok := rs.relayForByIp[ip]
|
||||
r, ok := rs.relayForByAddr[addr]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
|
||||
func (rs *RelayState) CopyRelayForIps() []netip.Addr {
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
|
||||
for relayIp := range rs.relayForByIp {
|
||||
currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
|
||||
for relayIp := range rs.relayForByAddr {
|
||||
currentRelays = append(currentRelays, relayIp)
|
||||
}
|
||||
return currentRelays
|
||||
@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
|
||||
func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
r, ok := rs.relayForByIp[vpnIp]
|
||||
r, ok := rs.relayForByAddr[vpnIp]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
|
||||
newRelay.State = Established
|
||||
newRelay.RemoteIndex = remoteIdx
|
||||
rs.relayForByIdx[r.LocalIndex] = &newRelay
|
||||
rs.relayForByIp[r.PeerIp] = &newRelay
|
||||
rs.relayForByAddr[r.PeerAddr] = &newRelay
|
||||
return true
|
||||
}
|
||||
|
||||
@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
|
||||
newRelay.State = Established
|
||||
newRelay.RemoteIndex = remoteIdx
|
||||
rs.relayForByIdx[r.LocalIndex] = &newRelay
|
||||
rs.relayForByIp[r.PeerIp] = &newRelay
|
||||
rs.relayForByAddr[r.PeerAddr] = &newRelay
|
||||
return &newRelay, true
|
||||
}
|
||||
|
||||
func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
r, ok := rs.relayForByIp[vpnIp]
|
||||
r, ok := rs.relayForByAddr[vpnIp]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
|
||||
func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
rs.relayForByIp[ip] = r
|
||||
rs.relayForByAddr[ip] = r
|
||||
rs.relayForByIdx[idx] = r
|
||||
}
|
||||
|
||||
@ -190,10 +215,16 @@ type HostInfo struct {
|
||||
ConnectionState *ConnectionState
|
||||
remoteIndexId uint32
|
||||
localIndexId uint32
|
||||
vpnIp netip.Addr
|
||||
recvError atomic.Uint32
|
||||
remoteCidr *bart.Table[struct{}]
|
||||
relayState RelayState
|
||||
|
||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
||||
// The host may have other vpn addresses that are outside our
|
||||
// vpn networks but were removed because they are not usable
|
||||
vpnAddrs []netip.Addr
|
||||
recvError atomic.Uint32
|
||||
|
||||
// networks are both all vpn and unsafe networks assigned to this host
|
||||
networks *bart.Table[struct{}]
|
||||
relayState RelayState
|
||||
|
||||
// HandshakePacket records the packets used to create this hostinfo
|
||||
// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
|
||||
@ -241,28 +272,26 @@ type cachedPacketMetrics struct {
|
||||
dropped metrics.Counter
|
||||
}
|
||||
|
||||
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
|
||||
hm := newHostMap(l, vpnCIDR)
|
||||
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
||||
hm := newHostMap(l)
|
||||
|
||||
hm.reload(c, true)
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
hm.reload(c, false)
|
||||
})
|
||||
|
||||
l.WithField("network", hm.vpnCIDR.String()).
|
||||
WithField("preferredRanges", hm.GetPreferredRanges()).
|
||||
l.WithField("preferredRanges", hm.GetPreferredRanges()).
|
||||
Info("Main HostMap created")
|
||||
|
||||
return hm
|
||||
}
|
||||
|
||||
func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
|
||||
func newHostMap(l *logrus.Logger) *HostMap {
|
||||
return &HostMap{
|
||||
Indexes: map[uint32]*HostInfo{},
|
||||
Relays: map[uint32]*HostInfo{},
|
||||
RemoteIndexes: map[uint32]*HostInfo{},
|
||||
Hosts: map[netip.Addr]*HostInfo{},
|
||||
vpnCIDR: vpnCIDR,
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
@ -305,17 +334,6 @@ func (hm *HostMap) EmitStats() {
|
||||
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
|
||||
}
|
||||
|
||||
func (hm *HostMap) RemoveRelay(localIdx uint32) {
|
||||
hm.Lock()
|
||||
_, ok := hm.Relays[localIdx]
|
||||
if !ok {
|
||||
hm.Unlock()
|
||||
return
|
||||
}
|
||||
delete(hm.Relays, localIdx)
|
||||
hm.Unlock()
|
||||
}
|
||||
|
||||
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
|
||||
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
|
||||
// Delete the host itself, ensuring it's not modified anymore
|
||||
@ -335,48 +353,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
|
||||
}
|
||||
|
||||
func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
|
||||
oldHostinfo := hm.Hosts[hostinfo.vpnIp]
|
||||
// Get the current primary, if it exists
|
||||
oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
|
||||
|
||||
// Every address in the hostinfo gets elevated to primary
|
||||
for _, vpnAddr := range hostinfo.vpnAddrs {
|
||||
//NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on
|
||||
// indexes so it should be fine.
|
||||
hm.Hosts[vpnAddr] = hostinfo
|
||||
}
|
||||
|
||||
// If we are already primary then we won't bother re-linking
|
||||
if oldHostinfo == hostinfo {
|
||||
return
|
||||
}
|
||||
|
||||
// Unlink this hostinfo
|
||||
if hostinfo.prev != nil {
|
||||
hostinfo.prev.next = hostinfo.next
|
||||
}
|
||||
|
||||
if hostinfo.next != nil {
|
||||
hostinfo.next.prev = hostinfo.prev
|
||||
}
|
||||
|
||||
hm.Hosts[hostinfo.vpnIp] = hostinfo
|
||||
|
||||
// If there wasn't a previous primary then clear out any links
|
||||
if oldHostinfo == nil {
|
||||
hostinfo.next = nil
|
||||
hostinfo.prev = nil
|
||||
return
|
||||
}
|
||||
|
||||
// Relink the hostinfo as primary
|
||||
hostinfo.next = oldHostinfo
|
||||
oldHostinfo.prev = hostinfo
|
||||
hostinfo.prev = nil
|
||||
}
|
||||
|
||||
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||
primary, ok := hm.Hosts[hostinfo.vpnIp]
|
||||
for _, addr := range hostinfo.vpnAddrs {
|
||||
h := hm.Hosts[addr]
|
||||
for h != nil {
|
||||
if h == hostinfo {
|
||||
hm.unlockedInnerDeleteHostInfo(h, addr)
|
||||
}
|
||||
h = h.next
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
|
||||
primary, ok := hm.Hosts[addr]
|
||||
isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil
|
||||
if ok && primary == hostinfo {
|
||||
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
|
||||
delete(hm.Hosts, hostinfo.vpnIp)
|
||||
// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
|
||||
delete(hm.Hosts, addr)
|
||||
if len(hm.Hosts) == 0 {
|
||||
hm.Hosts = map[netip.Addr]*HostInfo{}
|
||||
}
|
||||
|
||||
if hostinfo.next != nil {
|
||||
// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
|
||||
hm.Hosts[hostinfo.vpnIp] = hostinfo.next
|
||||
// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
|
||||
hm.Hosts[addr] = hostinfo.next
|
||||
// It is primary, there is no previous hostinfo now
|
||||
hostinfo.next.prev = nil
|
||||
}
|
||||
|
||||
} else {
|
||||
// Relink if we were in the middle of multiple hostinfos for this vpn ip
|
||||
// Relink if we were in the middle of multiple hostinfos for this vpn addr
|
||||
if hostinfo.prev != nil {
|
||||
hostinfo.prev.next = hostinfo.next
|
||||
}
|
||||
@ -406,10 +449,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
|
||||
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
Debug("Hostmap hostInfo deleted")
|
||||
}
|
||||
|
||||
if isLastHostinfo {
|
||||
// I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next
|
||||
// hops as 'Requested' so that new relay tunnels are created in the future.
|
||||
hm.unlockedDisestablishVpnAddrRelayFor(hostinfo)
|
||||
}
|
||||
// Clean up any local relay indexes for which I am acting as a relay hop
|
||||
for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
|
||||
delete(hm.Relays, localRelayIdx)
|
||||
}
|
||||
@ -448,11 +497,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
|
||||
return hm.queryVpnIp(vpnIp, nil)
|
||||
func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
|
||||
return hm.queryVpnAddr(vpnIp, nil)
|
||||
}
|
||||
|
||||
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
|
||||
func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
|
||||
hm.RLock()
|
||||
defer hm.RUnlock()
|
||||
|
||||
@ -460,17 +509,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
|
||||
if !ok {
|
||||
return nil, nil, errors.New("unable to find host")
|
||||
}
|
||||
|
||||
for h != nil {
|
||||
r, ok := h.relayState.QueryRelayForByIp(targetIp)
|
||||
if ok && r.State == Established {
|
||||
return h, r, nil
|
||||
for _, targetIp := range targetIps {
|
||||
r, ok := h.relayState.QueryRelayForByIp(targetIp)
|
||||
if ok && r.State == Established {
|
||||
return h, r, nil
|
||||
}
|
||||
}
|
||||
h = h.next
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("unable to find host with relay")
|
||||
}
|
||||
|
||||
func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
|
||||
func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
|
||||
for _, relayHostIp := range hi.relayState.CopyRelayIps() {
|
||||
if h, ok := hm.Hosts[relayHostIp]; ok {
|
||||
for h != nil {
|
||||
h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
|
||||
h = h.next
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, rs := range hi.relayState.CopyAllRelayFor() {
|
||||
if rs.Type == ForwardingType {
|
||||
if h, ok := hm.Hosts[rs.PeerAddr]; ok {
|
||||
for h != nil {
|
||||
h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
|
||||
h = h.next
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
|
||||
hm.RLock()
|
||||
if h, ok := hm.Hosts[vpnIp]; ok {
|
||||
hm.RUnlock()
|
||||
@ -491,25 +565,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
|
||||
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||
if f.serveDns {
|
||||
remoteCert := hostinfo.ConnectionState.peerCert
|
||||
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
||||
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
||||
}
|
||||
|
||||
existing := hm.Hosts[hostinfo.vpnIp]
|
||||
hm.Hosts[hostinfo.vpnIp] = hostinfo
|
||||
|
||||
if existing != nil {
|
||||
hostinfo.next = existing
|
||||
existing.prev = hostinfo
|
||||
for _, addr := range hostinfo.vpnAddrs {
|
||||
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
||||
}
|
||||
|
||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
|
||||
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
|
||||
Debug("Hostmap vpnIp added")
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
|
||||
existing := hm.Hosts[vpnAddr]
|
||||
hm.Hosts[vpnAddr] = hostinfo
|
||||
|
||||
if existing != nil && existing != hostinfo {
|
||||
hostinfo.next = existing
|
||||
existing.prev = hostinfo
|
||||
}
|
||||
|
||||
i := 1
|
||||
check := hostinfo
|
||||
@ -527,7 +606,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
|
||||
return *hm.preferredRanges.Load()
|
||||
}
|
||||
|
||||
func (hm *HostMap) ForEachVpnIp(f controlEach) {
|
||||
func (hm *HostMap) ForEachVpnAddr(f controlEach) {
|
||||
hm.RLock()
|
||||
defer hm.RUnlock()
|
||||
|
||||
@ -581,11 +660,11 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
|
||||
}
|
||||
|
||||
i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
|
||||
ifce.lightHouse.QueryServer(i.vpnIp)
|
||||
ifce.lightHouse.QueryServer(i.vpnAddrs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (i *HostInfo) GetCert() *cert.NebulaCertificate {
|
||||
func (i *HostInfo) GetCert() *cert.CachedCertificate {
|
||||
if i.ConnectionState != nil {
|
||||
return i.ConnectionState.peerCert
|
||||
}
|
||||
@ -596,7 +675,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
|
||||
// We copy here because we likely got this remote from a source that reuses the object
|
||||
if i.remote != remote {
|
||||
i.remote = remote
|
||||
i.remotes.LearnRemote(i.vpnIp, remote)
|
||||
i.remotes.LearnRemote(i.vpnAddrs[0], remote)
|
||||
}
|
||||
}
|
||||
|
||||
@ -647,29 +726,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
||||
if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
|
||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||
// Simple case, no CIDRTree needed
|
||||
return
|
||||
}
|
||||
|
||||
remoteCidr := new(bart.Table[struct{}])
|
||||
for _, ip := range c.Details.Ips {
|
||||
//TODO: IPV6-WORK what to do when ip is invalid?
|
||||
nip, _ := netip.AddrFromSlice(ip.IP)
|
||||
nip = nip.Unmap()
|
||||
bits, _ := ip.Mask.Size()
|
||||
remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
|
||||
i.networks = new(bart.Table[struct{}])
|
||||
for _, network := range networks {
|
||||
i.networks.Insert(network, struct{}{})
|
||||
}
|
||||
|
||||
for _, n := range c.Details.Subnets {
|
||||
//TODO: IPV6-WORK what to do when ip is invalid?
|
||||
nip, _ := netip.AddrFromSlice(n.IP)
|
||||
nip = nip.Unmap()
|
||||
bits, _ := n.Mask.Size()
|
||||
remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
|
||||
for _, network := range unsafeNetworks {
|
||||
i.networks.Insert(network, struct{}{})
|
||||
}
|
||||
i.remoteCidr = remoteCidr
|
||||
}
|
||||
|
||||
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||
@ -677,13 +747,13 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||
return logrus.NewEntry(l)
|
||||
}
|
||||
|
||||
li := l.WithField("vpnIp", i.vpnIp).
|
||||
li := l.WithField("vpnAddrs", i.vpnAddrs).
|
||||
WithField("localIndex", i.localIndexId).
|
||||
WithField("remoteIndex", i.remoteIndexId)
|
||||
|
||||
if connState := i.ConnectionState; connState != nil {
|
||||
if peerCert := connState.peerCert; peerCert != nil {
|
||||
li = li.WithField("certName", peerCert.Details.Name)
|
||||
li = li.WithField("certName", peerCert.Certificate.Name())
|
||||
}
|
||||
}
|
||||
|
||||
@ -692,9 +762,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||
|
||||
// Utility functions
|
||||
|
||||
func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||
//FIXME: This function is pretty garbage
|
||||
var ips []netip.Addr
|
||||
var finalAddrs []netip.Addr
|
||||
ifaces, _ := net.Interfaces()
|
||||
for _, i := range ifaces {
|
||||
allow := allowList.AllowName(i.Name)
|
||||
@ -706,39 +776,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||
continue
|
||||
}
|
||||
addrs, _ := i.Addrs()
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
for _, rawAddr := range addrs {
|
||||
var addr netip.Addr
|
||||
switch v := rawAddr.(type) {
|
||||
case *net.IPNet:
|
||||
//continue
|
||||
ip = v.IP
|
||||
addr, _ = netip.AddrFromSlice(v.IP)
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
addr, _ = netip.AddrFromSlice(v.IP)
|
||||
}
|
||||
|
||||
nip, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
if !addr.IsValid() {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("localIp", ip).Debug("ip was invalid for netip")
|
||||
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
|
||||
}
|
||||
continue
|
||||
}
|
||||
nip = nip.Unmap()
|
||||
addr = addr.Unmap()
|
||||
|
||||
//TODO: Filtering out link local for now, this is probably the most correct thing
|
||||
//TODO: Would be nice to filter out SLAAC MAC based ips as well
|
||||
if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
|
||||
allow := allowList.Allow(nip)
|
||||
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
||||
isAllowed := allowList.Allow(addr)
|
||||
if l.Level >= logrus.TraceLevel {
|
||||
l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
|
||||
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
|
||||
}
|
||||
if !allow {
|
||||
if !isAllowed {
|
||||
continue
|
||||
}
|
||||
|
||||
ips = append(ips, nip)
|
||||
finalAddrs = append(finalAddrs, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ips
|
||||
return finalAddrs
|
||||
}
|
||||
|
||||
@ -11,17 +11,14 @@ import (
|
||||
|
||||
func TestHostMap_MakePrimary(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
hm := newHostMap(
|
||||
l,
|
||||
netip.MustParsePrefix("10.0.0.1/24"),
|
||||
)
|
||||
hm := newHostMap(l)
|
||||
|
||||
f := &Interface{}
|
||||
|
||||
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
|
||||
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
|
||||
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
|
||||
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
|
||||
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
|
||||
h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
|
||||
h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
|
||||
h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
|
||||
|
||||
hm.unlockedAddHostInfo(h4, f)
|
||||
hm.unlockedAddHostInfo(h3, f)
|
||||
@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
||||
hm.unlockedAddHostInfo(h1, f)
|
||||
|
||||
// Make sure we go h1 -> h2 -> h3 -> h4
|
||||
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h1.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
||||
hm.MakePrimary(h3)
|
||||
|
||||
// Make sure we go h3 -> h1 -> h2 -> h4
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h3.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
||||
hm.MakePrimary(h4)
|
||||
|
||||
// Make sure we go h4 -> h3 -> h1 -> h2
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
||||
hm.MakePrimary(h4)
|
||||
|
||||
// Make sure we go h4 -> h3 -> h1 -> h2
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
||||
|
||||
func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
hm := newHostMap(
|
||||
l,
|
||||
netip.MustParsePrefix("10.0.0.1/24"),
|
||||
)
|
||||
hm := newHostMap(l)
|
||||
|
||||
f := &Interface{}
|
||||
|
||||
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
|
||||
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
|
||||
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
|
||||
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
|
||||
h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
|
||||
h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
|
||||
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
|
||||
h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
|
||||
h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
|
||||
h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
|
||||
h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
|
||||
h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
|
||||
|
||||
hm.unlockedAddHostInfo(h6, f)
|
||||
hm.unlockedAddHostInfo(h5, f)
|
||||
@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
assert.Nil(t, h)
|
||||
|
||||
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
|
||||
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h1.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
assert.Nil(t, h1.next)
|
||||
|
||||
// Make sure we go h2 -> h3 -> h4 -> h5
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
assert.Nil(t, h3.next)
|
||||
|
||||
// Make sure we go h2 -> h4 -> h5
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
assert.Nil(t, h5.next)
|
||||
|
||||
// Make sure we go h2 -> h4
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
assert.Nil(t, h2.next)
|
||||
|
||||
// Make sure we only have h4
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||
assert.Nil(t, prim.prev)
|
||||
assert.Nil(t, prim.next)
|
||||
@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
assert.Nil(t, h4.next)
|
||||
|
||||
// Make sure we have nil
|
||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
||||
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||
assert.Nil(t, prim)
|
||||
}
|
||||
|
||||
@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
|
||||
hm := NewHostMapFromConfig(
|
||||
l,
|
||||
netip.MustParsePrefix("10.0.0.1/24"),
|
||||
c,
|
||||
)
|
||||
hm := NewHostMapFromConfig(l, c)
|
||||
|
||||
toS := func(ipn []netip.Prefix) []string {
|
||||
var s []string
|
||||
|
||||
@ -9,8 +9,8 @@ import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func (i *HostInfo) GetVpnIp() netip.Addr {
|
||||
return i.vpnIp
|
||||
func (i *HostInfo) GetVpnAddrs() []netip.Addr {
|
||||
return i.vpnAddrs
|
||||
}
|
||||
|
||||
func (i *HostInfo) GetLocalIndex() uint32 {
|
||||
|
||||
58
inside.go
58
inside.go
@ -20,14 +20,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
}
|
||||
|
||||
// Ignore local broadcast packets
|
||||
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
|
||||
return
|
||||
if f.dropLocalBroadcast {
|
||||
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
|
||||
if found {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if fwPacket.RemoteIP == f.myVpnNet.Addr() {
|
||||
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
|
||||
if found {
|
||||
// Immediately forward packets from self to self.
|
||||
// This should only happen on Darwin-based and FreeBSD hosts, which
|
||||
// routes packets from the Nebula IP to the Nebula IP through the Nebula
|
||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||
// TUN device.
|
||||
if immediatelyForwardToSelf {
|
||||
_, err := f.readers[q].Write(packet)
|
||||
@ -36,25 +40,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
}
|
||||
}
|
||||
// Otherwise, drop. On linux, we should never see these packets - Linux
|
||||
// routes packets from the nebula IP to the nebula IP through the loopback device.
|
||||
// routes packets from the nebula addr to the nebula addr through the loopback device.
|
||||
return
|
||||
}
|
||||
|
||||
// Ignore multicast packets
|
||||
if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
|
||||
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
|
||||
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnIp", fwPacket.RemoteIP).
|
||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||
WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -117,21 +121,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||
}
|
||||
|
||||
func (f *Interface) Handshake(vpnIp netip.Addr) {
|
||||
f.getOrHandshake(vpnIp, nil)
|
||||
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||
f.getOrHandshake(vpnAddr, nil)
|
||||
}
|
||||
|
||||
// getOrHandshake returns nil if the vpnIp is not routable.
|
||||
// getOrHandshake returns nil if the vpnAddr is not routable.
|
||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
||||
func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
if !f.myVpnNet.Contains(vpnIp) {
|
||||
vpnIp = f.inside.RouteFor(vpnIp)
|
||||
if !vpnIp.IsValid() {
|
||||
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
|
||||
if !found {
|
||||
vpnAddr = f.inside.RouteFor(vpnAddr)
|
||||
if !vpnAddr.IsValid() {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
|
||||
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
|
||||
}
|
||||
|
||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||
@ -156,16 +161,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
}
|
||||
|
||||
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
|
||||
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
|
||||
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
|
||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||
})
|
||||
|
||||
if hostInfo == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnIp", vpnIp).
|
||||
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
||||
f.l.WithField("vpnAddr", vpnAddr).
|
||||
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -258,7 +263,6 @@ func (f *Interface) SendVia(via *HostInfo,
|
||||
|
||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
|
||||
if ci.eKey == nil {
|
||||
//TODO: log warning
|
||||
return
|
||||
}
|
||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
||||
@ -285,14 +289,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
f.connectionManager.Out(hostinfo.localIndexId)
|
||||
|
||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
||||
// all our IPs and enable a faster roaming.
|
||||
// all our addrs and enable a faster roaming.
|
||||
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
|
||||
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||
f.lightHouse.QueryServer(hostinfo.vpnIp)
|
||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||
hostinfo.lastRebindCount = f.rebindCount
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
}
|
||||
}
|
||||
|
||||
@ -324,7 +328,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
} else {
|
||||
// Try to send via a relay
|
||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
hostinfo.relayState.DeleteRelay(relayIP)
|
||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||
|
||||
160
interface.go
160
interface.go
@ -2,7 +2,6 @@ package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -12,6 +11,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
@ -28,7 +28,6 @@ type InterfaceConfig struct {
|
||||
Outside udp.Conn
|
||||
Inside overlay.Device
|
||||
pki *PKI
|
||||
Cipher string
|
||||
Firewall *Firewall
|
||||
ServeDns bool
|
||||
HandshakeManager *HandshakeManager
|
||||
@ -52,25 +51,27 @@ type InterfaceConfig struct {
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
hostMap *HostMap
|
||||
outside udp.Conn
|
||||
inside overlay.Device
|
||||
pki *PKI
|
||||
cipher string
|
||||
firewall *Firewall
|
||||
connectionManager *connectionManager
|
||||
handshakeManager *HandshakeManager
|
||||
serveDns bool
|
||||
createTime time.Time
|
||||
lightHouse *LightHouse
|
||||
myBroadcastAddr netip.Addr
|
||||
myVpnNet netip.Prefix
|
||||
dropLocalBroadcast bool
|
||||
dropMulticast bool
|
||||
routines int
|
||||
disconnectInvalid atomic.Bool
|
||||
closed atomic.Bool
|
||||
relayManager *relayManager
|
||||
hostMap *HostMap
|
||||
outside udp.Conn
|
||||
inside overlay.Device
|
||||
pki *PKI
|
||||
firewall *Firewall
|
||||
connectionManager *connectionManager
|
||||
handshakeManager *HandshakeManager
|
||||
serveDns bool
|
||||
createTime time.Time
|
||||
lightHouse *LightHouse
|
||||
myBroadcastAddrsTable *bart.Table[struct{}]
|
||||
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
||||
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
|
||||
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
|
||||
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
|
||||
dropLocalBroadcast bool
|
||||
dropMulticast bool
|
||||
routines int
|
||||
disconnectInvalid atomic.Bool
|
||||
closed atomic.Bool
|
||||
relayManager *relayManager
|
||||
|
||||
tryPromoteEvery atomic.Uint32
|
||||
reQueryEvery atomic.Uint32
|
||||
@ -102,9 +103,11 @@ type EncWriter interface {
|
||||
out []byte,
|
||||
nocopy bool,
|
||||
)
|
||||
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
|
||||
SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
|
||||
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
|
||||
Handshake(vpnIp netip.Addr)
|
||||
Handshake(vpnAddr netip.Addr)
|
||||
GetHostInfo(vpnAddr netip.Addr) *HostInfo
|
||||
GetCertState() *CertState
|
||||
}
|
||||
|
||||
type sendRecvErrorConfig uint8
|
||||
@ -115,10 +118,10 @@ const (
|
||||
sendRecvErrorPrivate
|
||||
)
|
||||
|
||||
func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
|
||||
func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
|
||||
switch s {
|
||||
case sendRecvErrorPrivate:
|
||||
return ip.Addr().IsPrivate()
|
||||
return endpoint.Addr().IsPrivate()
|
||||
case sendRecvErrorAlways:
|
||||
return true
|
||||
case sendRecvErrorNever:
|
||||
@ -155,47 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
return nil, errors.New("no firewall rules")
|
||||
}
|
||||
|
||||
certificate := c.pki.GetCertState().Certificate
|
||||
|
||||
myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
|
||||
}
|
||||
|
||||
myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
|
||||
}
|
||||
|
||||
myVpnAddr = myVpnAddr.Unmap()
|
||||
myVpnMask = myVpnMask.Unmap()
|
||||
|
||||
if myVpnAddr.BitLen() != myVpnMask.BitLen() {
|
||||
return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
|
||||
}
|
||||
|
||||
ones, _ := certificate.Details.Ips[0].Mask.Size()
|
||||
myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
|
||||
|
||||
cs := c.pki.getCertState()
|
||||
ifce := &Interface{
|
||||
pki: c.pki,
|
||||
hostMap: c.HostMap,
|
||||
outside: c.Outside,
|
||||
inside: c.Inside,
|
||||
cipher: c.Cipher,
|
||||
firewall: c.Firewall,
|
||||
serveDns: c.ServeDns,
|
||||
handshakeManager: c.HandshakeManager,
|
||||
createTime: time.Now(),
|
||||
lightHouse: c.lightHouse,
|
||||
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||
dropMulticast: c.DropMulticast,
|
||||
routines: c.routines,
|
||||
version: c.version,
|
||||
writers: make([]udp.Conn, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
myVpnNet: myVpnNet,
|
||||
relayManager: c.relayManager,
|
||||
pki: c.pki,
|
||||
hostMap: c.HostMap,
|
||||
outside: c.Outside,
|
||||
inside: c.Inside,
|
||||
firewall: c.Firewall,
|
||||
serveDns: c.ServeDns,
|
||||
handshakeManager: c.HandshakeManager,
|
||||
createTime: time.Now(),
|
||||
lightHouse: c.lightHouse,
|
||||
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||
dropMulticast: c.DropMulticast,
|
||||
routines: c.routines,
|
||||
version: c.version,
|
||||
writers: make([]udp.Conn, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
myVpnNetworks: cs.myVpnNetworks,
|
||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||
myVpnAddrs: cs.myVpnAddrs,
|
||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
|
||||
relayManager: c.relayManager,
|
||||
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
|
||||
@ -209,12 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
l: c.l,
|
||||
}
|
||||
|
||||
if myVpnAddr.Is4() {
|
||||
addr := myVpnNet.Masked().Addr().As4()
|
||||
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
|
||||
ifce.myBroadcastAddr = netip.AddrFrom4(addr)
|
||||
}
|
||||
|
||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||
@ -235,7 +214,7 @@ func (f *Interface) activate() {
|
||||
f.l.WithError(err).Error("Failed to get udp listen address")
|
||||
}
|
||||
|
||||
f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
|
||||
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
|
||||
WithField("build", f.version).WithField("udpAddr", addr).
|
||||
WithField("boringcrypto", boringEnabled()).
|
||||
Info("Nebula interface is active")
|
||||
@ -276,16 +255,22 @@ func (f *Interface) listenOut(i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
var li udp.Conn
|
||||
// TODO clean this up with a coherent interface for each outside connection
|
||||
if i > 0 {
|
||||
li = f.writers[i]
|
||||
} else {
|
||||
li = f.outside
|
||||
}
|
||||
|
||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
|
||||
plaintext := make([]byte, udp.MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||
})
|
||||
}
|
||||
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
@ -342,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
||||
return
|
||||
}
|
||||
|
||||
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
|
||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||
return
|
||||
@ -425,6 +410,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
udpStats := udp.NewUDPStatsEmitter(f.writers)
|
||||
|
||||
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
||||
certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
|
||||
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
|
||||
|
||||
for {
|
||||
select {
|
||||
@ -434,11 +421,30 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
f.firewall.EmitStats()
|
||||
f.handshakeManager.EmitStats()
|
||||
udpStats()
|
||||
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
|
||||
|
||||
certState := f.pki.getCertState()
|
||||
defaultCrt := certState.GetDefaultCertificate()
|
||||
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
||||
certDefaultVersion.Update(int64(defaultCrt.Version()))
|
||||
|
||||
// Report the max certificate version we are capable of using
|
||||
if certState.v2Cert != nil {
|
||||
certMaxVersion.Update(int64(certState.v2Cert.Version()))
|
||||
} else {
|
||||
certMaxVersion.Update(int64(certState.v1Cert.Version()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
|
||||
return f.hostMap.QueryVpnAddr(vpnIp)
|
||||
}
|
||||
|
||||
func (f *Interface) GetCertState() *CertState {
|
||||
return f.pki.getCertState()
|
||||
}
|
||||
|
||||
func (f *Interface) Close() error {
|
||||
f.closed.Store(true)
|
||||
|
||||
|
||||
@ -6,8 +6,6 @@ import (
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
//TODO: IPV6-WORK can probably delete this
|
||||
|
||||
const (
|
||||
// Need 96 bytes for the largest reject packet:
|
||||
// - 20 byte ipv4 header
|
||||
|
||||
944
lighthouse.go
944
lighthouse.go
File diff suppressed because it is too large
Load Diff
@ -7,69 +7,61 @@ import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
//TODO: Add a test to ensure udpAddr is copied and not reused
|
||||
|
||||
func TestOldIPv4Only(t *testing.T) {
|
||||
// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
|
||||
b := []byte{8, 129, 130, 132, 80, 16, 10}
|
||||
var m Ip4AndPort
|
||||
var m V4AddrPort
|
||||
err := m.Unmarshal(b)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
ip := netip.MustParseAddr("10.1.1.1")
|
||||
bp := ip.As4()
|
||||
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
|
||||
}
|
||||
|
||||
func TestNewLhQuery(t *testing.T) {
|
||||
myIp, err := netip.ParseAddr("192.1.1.1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Generating a new lh query should work
|
||||
a := NewLhQueryByInt(myIp)
|
||||
|
||||
// The result should be a nebulameta protobuf
|
||||
assert.IsType(t, &NebulaMeta{}, a)
|
||||
|
||||
// It should also Marshal fine
|
||||
b, err := a.Marshal()
|
||||
assert.Nil(t, err)
|
||||
|
||||
// and then Unmarshal fine
|
||||
n := &NebulaMeta{}
|
||||
err = n.Unmarshal(b)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
|
||||
}
|
||||
|
||||
func Test_lhStaticMapping(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh1 := "10.128.0.2"
|
||||
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||
_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
||||
assert.Nil(t, err)
|
||||
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
lh2 := "10.128.0.3"
|
||||
c = config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
|
||||
_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
||||
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||
}
|
||||
|
||||
func TestReloadLighthouseInterval(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh1 := "10.128.0.2"
|
||||
|
||||
c := config.NewC(l)
|
||||
@ -79,77 +71,82 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
||||
}
|
||||
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
|
||||
// The first one routine is kicked off by main.go currently, lets make sure that one dies
|
||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
|
||||
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
|
||||
assert.Equal(t, int64(5), lh.interval.Load())
|
||||
|
||||
// Subsequent calls are killed off by the LightHouse.Reload function
|
||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
|
||||
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
|
||||
assert.Equal(t, int64(10), lh.interval.Load())
|
||||
|
||||
// If this completes then nothing is stealing our reload routine
|
||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
|
||||
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
|
||||
assert.Equal(t, int64(11), lh.interval.Load())
|
||||
}
|
||||
|
||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
l := test.NewLogger()
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
|
||||
c := config.NewC(l)
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
||||
if !assert.NoError(b, err) {
|
||||
b.Fatal()
|
||||
}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(b, err)
|
||||
|
||||
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
|
||||
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
|
||||
|
||||
vpnIp3 := netip.MustParseAddr("0.0.0.3")
|
||||
lh.addrMap[vpnIp3] = NewRemoteList(nil)
|
||||
lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
|
||||
lh.addrMap[vpnIp3].unlockedSetV4(
|
||||
vpnIp3,
|
||||
vpnIp3,
|
||||
[]*Ip4AndPort{
|
||||
NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
|
||||
NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
|
||||
[]*V4AddrPort{
|
||||
netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
|
||||
netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
|
||||
},
|
||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
||||
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||
)
|
||||
|
||||
rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
|
||||
rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
|
||||
vpnIp2 := netip.MustParseAddr("0.0.0.3")
|
||||
lh.addrMap[vpnIp2] = NewRemoteList(nil)
|
||||
lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
|
||||
lh.addrMap[vpnIp2].unlockedSetV4(
|
||||
vpnIp3,
|
||||
vpnIp3,
|
||||
[]*Ip4AndPort{
|
||||
NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
|
||||
NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
|
||||
[]*V4AddrPort{
|
||||
netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
|
||||
netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
|
||||
},
|
||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
||||
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||
)
|
||||
|
||||
mw := &mockEncWriter{}
|
||||
|
||||
hi := []netip.Addr{vpnIp2}
|
||||
b.Run("notfound", func(b *testing.B) {
|
||||
lhh := lh.NewRequestHandler()
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: 4,
|
||||
Ip4AndPorts: nil,
|
||||
OldVpnAddr: 4,
|
||||
V4AddrPorts: nil,
|
||||
},
|
||||
}
|
||||
p, err := req.Marshal()
|
||||
assert.NoError(b, err)
|
||||
require.NoError(b, err)
|
||||
for n := 0; n < b.N; n++ {
|
||||
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
|
||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||
}
|
||||
})
|
||||
b.Run("found", func(b *testing.B) {
|
||||
@ -157,15 +154,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: 3,
|
||||
Ip4AndPorts: nil,
|
||||
OldVpnAddr: 3,
|
||||
V4AddrPorts: nil,
|
||||
},
|
||||
}
|
||||
p, err := req.Marshal()
|
||||
assert.NoError(b, err)
|
||||
require.NoError(b, err)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
|
||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -197,40 +194,49 @@ func TestLighthouse_Memory(t *testing.T) {
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
require.NoError(t, err)
|
||||
lhh := lh.NewRequestHandler()
|
||||
|
||||
// Test that my first update responds with just that
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
|
||||
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
|
||||
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
|
||||
|
||||
// Ensure we don't accumulate addresses
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
|
||||
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
|
||||
|
||||
// Grow it back to 2
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
||||
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
|
||||
|
||||
// Update a different host and ask about it
|
||||
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
|
||||
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
||||
assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
||||
|
||||
// Have both hosts ask about the other
|
||||
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
||||
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
|
||||
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
||||
assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
||||
|
||||
// Make sure we didn't get changed
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
||||
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
|
||||
|
||||
// Ensure proper ordering and limiting
|
||||
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
|
||||
@ -255,7 +261,7 @@ func TestLighthouse_Memory(t *testing.T) {
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(
|
||||
t,
|
||||
r.msg.Details.Ip4AndPorts,
|
||||
r.msg.Details.V4AddrPorts,
|
||||
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
|
||||
)
|
||||
|
||||
@ -265,7 +271,7 @@ func TestLighthouse_Memory(t *testing.T) {
|
||||
good := netip.MustParseAddrPort("1.128.0.99:4242")
|
||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
|
||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
|
||||
assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
|
||||
}
|
||||
|
||||
func TestLighthouse_reload(t *testing.T) {
|
||||
@ -273,8 +279,17 @@ func TestLighthouse_reload(t *testing.T) {
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
nc := map[interface{}]interface{}{
|
||||
"static_host_map": map[interface{}]interface{}{
|
||||
@ -282,21 +297,24 @@ func TestLighthouse_reload(t *testing.T) {
|
||||
},
|
||||
}
|
||||
rc, err := yaml.Marshal(nc)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
c.ReloadConfigString(string(rc))
|
||||
|
||||
err = lh.reload(c, false)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
|
||||
//TODO: IPV6-WORK
|
||||
bip := queryVpnIp.As4()
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: binary.BigEndian.Uint32(bip[:]),
|
||||
},
|
||||
Type: NebulaMeta_HostQuery,
|
||||
Details: &NebulaMetaDetails{},
|
||||
}
|
||||
|
||||
if queryVpnIp.Is4() {
|
||||
bip := queryVpnIp.As4()
|
||||
req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
|
||||
} else {
|
||||
req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp)
|
||||
}
|
||||
|
||||
b, err := req.Marshal()
|
||||
@ -308,23 +326,29 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
|
||||
w := &testEncWriter{
|
||||
metaFilter: &filter,
|
||||
}
|
||||
lhh.HandleRequest(fromAddr, myVpnIp, b, w)
|
||||
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
|
||||
return w.lastReply
|
||||
}
|
||||
|
||||
func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
|
||||
//TODO: IPV6-WORK
|
||||
bip := vpnIp.As4()
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostUpdateNotification,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: binary.BigEndian.Uint32(bip[:]),
|
||||
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
|
||||
},
|
||||
Type: NebulaMeta_HostUpdateNotification,
|
||||
Details: &NebulaMetaDetails{},
|
||||
}
|
||||
|
||||
for k, v := range addrs {
|
||||
req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
|
||||
if vpnIp.Is4() {
|
||||
bip := vpnIp.As4()
|
||||
req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
|
||||
} else {
|
||||
req.Details.VpnAddr = netAddrToProtoAddr(vpnIp)
|
||||
}
|
||||
|
||||
for _, v := range addrs {
|
||||
if v.Addr().Is4() {
|
||||
req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port()))
|
||||
} else {
|
||||
req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port()))
|
||||
}
|
||||
}
|
||||
|
||||
b, err := req.Marshal()
|
||||
@ -333,75 +357,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
|
||||
}
|
||||
|
||||
w := &testEncWriter{}
|
||||
lhh.HandleRequest(fromAddr, vpnIp, b, w)
|
||||
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
|
||||
}
|
||||
|
||||
//TODO: this is a RemoteList test
|
||||
//func Test_lhRemoteAllowList(t *testing.T) {
|
||||
// l := NewLogger()
|
||||
// c := NewConfig(l)
|
||||
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
|
||||
// "10.20.0.0/12": false,
|
||||
// }
|
||||
// allowList, err := c.GetAllowList("remoteallowlist", false)
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// lh1 := "10.128.0.2"
|
||||
// lh1IP := net.ParseIP(lh1)
|
||||
//
|
||||
// udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
||||
//
|
||||
// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
// lh.SetRemoteAllowList(allowList)
|
||||
//
|
||||
// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
|
||||
// remote1IP := net.ParseIP("10.20.0.3")
|
||||
// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
|
||||
// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
|
||||
// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
|
||||
// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
|
||||
//
|
||||
// // Make sure a good ip enters the cache and addrMap
|
||||
// remote2IP := net.ParseIP("10.128.0.3")
|
||||
// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
|
||||
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
|
||||
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
|
||||
//
|
||||
// // Another good ip gets into the cache, ordering is inverted
|
||||
// remote3IP := net.ParseIP("10.128.0.4")
|
||||
// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
|
||||
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
|
||||
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
|
||||
//
|
||||
// // If we exceed the length limit we should only have the most recent addresses
|
||||
// addedAddrs := []*udpAddr{}
|
||||
// for i := 0; i < 11; i++ {
|
||||
// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
|
||||
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
|
||||
// // The first entry here is a duplicate, don't add it to the assert list
|
||||
// if i != 0 {
|
||||
// addedAddrs = append(addedAddrs, remoteUDPAddr)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // We should only have the last 10 of what we tried to add
|
||||
// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
|
||||
// assertUdpAddrInArray(
|
||||
// t,
|
||||
// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
|
||||
// addedAddrs[0],
|
||||
// addedAddrs[1],
|
||||
// addedAddrs[2],
|
||||
// addedAddrs[3],
|
||||
// addedAddrs[4],
|
||||
// addedAddrs[5],
|
||||
// addedAddrs[6],
|
||||
// addedAddrs[7],
|
||||
// addedAddrs[8],
|
||||
// addedAddrs[9],
|
||||
// )
|
||||
//}
|
||||
|
||||
type testLhReply struct {
|
||||
nebType header.MessageType
|
||||
nebSubType header.MessageSubType
|
||||
@ -410,8 +368,9 @@ type testLhReply struct {
|
||||
}
|
||||
|
||||
type testEncWriter struct {
|
||||
lastReply testLhReply
|
||||
metaFilter *NebulaMeta_MessageType
|
||||
lastReply testLhReply
|
||||
metaFilter *NebulaMeta_MessageType
|
||||
protocolVersion cert.Version
|
||||
}
|
||||
|
||||
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
|
||||
@ -426,7 +385,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
|
||||
tw.lastReply = testLhReply{
|
||||
nebType: t,
|
||||
nebSubType: st,
|
||||
vpnIp: hostinfo.vpnIp,
|
||||
vpnIp: hostinfo.vpnAddrs[0],
|
||||
msg: msg,
|
||||
}
|
||||
}
|
||||
@ -436,7 +395,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
|
||||
func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
|
||||
msg := &NebulaMeta{}
|
||||
err := msg.Unmarshal(p)
|
||||
if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
|
||||
@ -453,17 +412,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tw *testEncWriter) GetCertState() *CertState {
|
||||
return &CertState{defaultVersion: tw.protocolVersion}
|
||||
}
|
||||
|
||||
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
||||
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
|
||||
func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
|
||||
if !assert.Len(t, have, len(want)) {
|
||||
return
|
||||
}
|
||||
|
||||
for k, w := range want {
|
||||
//TODO: IPV6-WORK
|
||||
h := AddrPortFromIp4AndPort(have[k])
|
||||
h := protoV4AddrPortToNetAddrPort(have[k])
|
||||
if !(h == w) {
|
||||
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_findNetworkUnion(t *testing.T) {
|
||||
var out netip.Addr
|
||||
var ok bool
|
||||
|
||||
tenDot := netip.MustParsePrefix("10.0.0.0/8")
|
||||
oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
|
||||
fe80 := netip.MustParsePrefix("fe80::/8")
|
||||
fc00 := netip.MustParsePrefix("fc00::/7")
|
||||
|
||||
a1 := netip.MustParseAddr("10.0.0.1")
|
||||
afe81 := netip.MustParseAddr("fe80::1")
|
||||
|
||||
//simple
|
||||
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, a1)
|
||||
|
||||
//mixed lengths
|
||||
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, a1)
|
||||
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, a1)
|
||||
|
||||
//mixed family
|
||||
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, a1)
|
||||
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, a1)
|
||||
|
||||
//ordering
|
||||
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, a1)
|
||||
out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, afe81)
|
||||
|
||||
//some mismatches
|
||||
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, afe81)
|
||||
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, out, afe81)
|
||||
|
||||
//falsey cases
|
||||
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
|
||||
assert.False(t, ok)
|
||||
out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
|
||||
assert.False(t, ok)
|
||||
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
|
||||
assert.False(t, ok)
|
||||
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
39
main.go
39
main.go
@ -2,7 +2,6 @@ package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -61,25 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
||||
}
|
||||
|
||||
certificate := pki.GetCertState().Certificate
|
||||
fw, err := NewFirewallFromConfig(l, certificate, c)
|
||||
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
||||
}
|
||||
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
|
||||
|
||||
ones, _ := certificate.Details.Ips[0].Mask.Size()
|
||||
addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
|
||||
if !ok {
|
||||
err = util.NewContextualError(
|
||||
"Invalid ip address in certificate",
|
||||
m{"vpnIp": certificate.Details.Ips[0].IP},
|
||||
nil,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
tunCidr := netip.PrefixFrom(addr, ones)
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
||||
@ -142,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
deviceFactory = overlay.NewDeviceFromConfig
|
||||
}
|
||||
|
||||
tun, err = deviceFactory(c, l, tunCidr, routines)
|
||||
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
|
||||
}
|
||||
@ -197,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
}
|
||||
|
||||
hostMap := NewHostMapFromConfig(l, tunCidr, c)
|
||||
hostMap := NewHostMapFromConfig(l, c)
|
||||
punchy := NewPunchyFromConfig(l, c)
|
||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
|
||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
||||
}
|
||||
@ -242,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
Inside: tun,
|
||||
Outside: udpConns[0],
|
||||
pki: pki,
|
||||
Cipher: c.GetString("cipher", "aes"),
|
||||
Firewall: fw,
|
||||
ServeDns: serveDns,
|
||||
HandshakeManager: handshakeManager,
|
||||
@ -264,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
l: l,
|
||||
}
|
||||
|
||||
switch ifConfig.Cipher {
|
||||
case "aes":
|
||||
noiseEndianness = binary.BigEndian
|
||||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
||||
}
|
||||
|
||||
var ifce *Interface
|
||||
if !configTest {
|
||||
ifce, err = NewInterface(ctx, ifConfig)
|
||||
@ -280,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
return nil, fmt.Errorf("failed to initialize interface: %s", err)
|
||||
}
|
||||
|
||||
// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
|
||||
// I don't want to make this initial commit too far-reaching though
|
||||
ifce.writers = udpConns
|
||||
lightHouse.ifce = ifce
|
||||
|
||||
@ -293,8 +267,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
go handshakeManager.Run(ctx)
|
||||
}
|
||||
|
||||
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
|
||||
// a context so that they can exit when the context is Done.
|
||||
statsStart, err := startStats(l, c, buildVersion, configTest)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
|
||||
@ -304,7 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
//TODO: check if we _should_ be emitting stats
|
||||
go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
|
||||
|
||||
attachCommands(l, c, ssh, ifce)
|
||||
@ -313,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
var dnsStart func()
|
||||
if lightHouse.amLighthouse && serveDns {
|
||||
l.Debugln("Starting dns server")
|
||||
dnsStart = dnsMain(l, hostMap, c)
|
||||
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
||||
}
|
||||
|
||||
return &Control{
|
||||
|
||||
@ -7,8 +7,6 @@ import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
//TODO: this can probably move into the header package
|
||||
|
||||
type MessageMetrics struct {
|
||||
rx [][]metrics.Counter
|
||||
tx [][]metrics.Counter
|
||||
|
||||
18
metadata.go
18
metadata.go
@ -1,18 +0,0 @@
|
||||
package nebula
|
||||
|
||||
/*
|
||||
|
||||
import (
|
||||
proto "google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func HandleMetaProto(p []byte) {
|
||||
m := &NebulaMeta{}
|
||||
err := proto.Unmarshal(p, m)
|
||||
if err != nil {
|
||||
l.Debugf("problem unmarshaling meta message: %s", err)
|
||||
}
|
||||
//fmt.Println(m)
|
||||
}
|
||||
|
||||
*/
|
||||
921
nebula.pb.go
921
nebula.pb.go
File diff suppressed because it is too large
Load Diff
32
nebula.proto
32
nebula.proto
@ -23,19 +23,28 @@ message NebulaMeta {
|
||||
}
|
||||
|
||||
message NebulaMetaDetails {
|
||||
uint32 VpnIp = 1;
|
||||
repeated Ip4AndPort Ip4AndPorts = 2;
|
||||
repeated Ip6AndPort Ip6AndPorts = 4;
|
||||
repeated uint32 RelayVpnIp = 5;
|
||||
uint32 OldVpnAddr = 1 [deprecated = true];
|
||||
Addr VpnAddr = 6;
|
||||
|
||||
repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
|
||||
repeated Addr RelayVpnAddrs = 7;
|
||||
|
||||
repeated V4AddrPort V4AddrPorts = 2;
|
||||
repeated V6AddrPort V6AddrPorts = 4;
|
||||
uint32 counter = 3;
|
||||
}
|
||||
|
||||
message Ip4AndPort {
|
||||
uint32 Ip = 1;
|
||||
message Addr {
|
||||
uint64 Hi = 1;
|
||||
uint64 Lo = 2;
|
||||
}
|
||||
|
||||
message V4AddrPort {
|
||||
uint32 Addr = 1;
|
||||
uint32 Port = 2;
|
||||
}
|
||||
|
||||
message Ip6AndPort {
|
||||
message V6AddrPort {
|
||||
uint64 Hi = 1;
|
||||
uint64 Lo = 2;
|
||||
uint32 Port = 3;
|
||||
@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
|
||||
uint32 ResponderIndex = 3;
|
||||
uint64 Cookie = 4;
|
||||
uint64 Time = 5;
|
||||
uint32 CertVersion = 8;
|
||||
// reserved for WIP multiport
|
||||
reserved 6, 7;
|
||||
}
|
||||
@ -76,6 +86,10 @@ message NebulaControl {
|
||||
|
||||
uint32 InitiatorRelayIndex = 2;
|
||||
uint32 ResponderRelayIndex = 3;
|
||||
uint32 RelayToIp = 4;
|
||||
uint32 RelayFromIp = 5;
|
||||
|
||||
uint32 OldRelayToAddr = 4 [deprecated = true];
|
||||
uint32 OldRelayFromAddr = 5 [deprecated = true];
|
||||
|
||||
Addr RelayToAddr = 6;
|
||||
Addr RelayFromAddr = 7;
|
||||
}
|
||||
|
||||
289
outside.go
289
outside.go
@ -3,46 +3,25 @@ package nebula
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/google/gopacket/layers"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
minFwPacketLen = 4
|
||||
)
|
||||
|
||||
// TODO: IPV6-WORK this can likely be removed now
|
||||
func readOutsidePackets(f *Interface) udp.EncReader {
|
||||
return func(
|
||||
addr netip.AddrPort,
|
||||
out []byte,
|
||||
packet []byte,
|
||||
header *header.H,
|
||||
fwPacket *firewall.Packet,
|
||||
lhh udp.LightHouseHandlerFunc,
|
||||
nb []byte,
|
||||
q int,
|
||||
localCache firewall.ConntrackCache,
|
||||
) {
|
||||
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := h.Parse(packet)
|
||||
if err != nil {
|
||||
// TODO: best if we return this and let caller log
|
||||
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
|
||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||
if len(packet) > 1 {
|
||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
|
||||
@ -52,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
|
||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||
if ip.IsValid() {
|
||||
if f.myVpnNet.Contains(ip.Addr()) {
|
||||
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
|
||||
if found {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||
}
|
||||
@ -109,7 +89,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
if !ok {
|
||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||
// its internal mapping. This should never happen.
|
||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
||||
return
|
||||
}
|
||||
|
||||
@ -121,9 +101,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
return
|
||||
case ForwardingType:
|
||||
// Find the target HostInfo relay object
|
||||
targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
|
||||
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
|
||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
||||
return
|
||||
}
|
||||
|
||||
@ -139,7 +119,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||
}
|
||||
} else {
|
||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -156,13 +136,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt lighthouse packet")
|
||||
|
||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
|
||||
return
|
||||
}
|
||||
|
||||
lhf(ip, hostinfo.vpnIp, d)
|
||||
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
||||
|
||||
// Fallthrough to the bottom to record incoming traffic
|
||||
|
||||
@ -177,9 +154,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt test packet")
|
||||
|
||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
|
||||
return
|
||||
}
|
||||
|
||||
@ -229,14 +203,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
Error("Failed to decrypt Control packet")
|
||||
return
|
||||
}
|
||||
m := &NebulaControl{}
|
||||
err = m.Unmarshal(d)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
|
||||
break
|
||||
}
|
||||
|
||||
f.relayManager.HandleControlMsg(hostinfo, m, f)
|
||||
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
||||
|
||||
default:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
@ -253,8 +221,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
||||
final := f.hostMap.DeleteHostInfo(hostInfo)
|
||||
if final {
|
||||
// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
|
||||
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
|
||||
// We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
|
||||
f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
|
||||
}
|
||||
}
|
||||
|
||||
@ -263,25 +231,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
||||
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||
}
|
||||
|
||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
|
||||
if ip.IsValid() && hostinfo.remote != ip {
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
|
||||
hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
|
||||
if udpAddr.IsValid() && hostinfo.remote != udpAddr {
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
|
||||
hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
return
|
||||
}
|
||||
if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||
|
||||
if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
||||
Info("Host roamed to new udp ip/port.")
|
||||
hostinfo.lastRoam = time.Now()
|
||||
hostinfo.lastRoamRemote = hostinfo.remote
|
||||
hostinfo.SetRemote(ip)
|
||||
hostinfo.SetRemote(udpAddr)
|
||||
}
|
||||
|
||||
}
|
||||
@ -301,24 +270,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
|
||||
return true
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPacketTooShort = errors.New("packet is too short")
|
||||
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
|
||||
ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
|
||||
ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short")
|
||||
ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short")
|
||||
ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
|
||||
)
|
||||
|
||||
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
||||
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
// Do we at least have an ipv4 header worth of data?
|
||||
if len(data) < ipv4.HeaderLen {
|
||||
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
|
||||
if len(data) < 1 {
|
||||
return ErrPacketTooShort
|
||||
}
|
||||
|
||||
// Is it an ipv4 packet?
|
||||
if int((data[0]>>4)&0x0f) != 4 {
|
||||
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
|
||||
version := int((data[0] >> 4) & 0x0f)
|
||||
switch version {
|
||||
case ipv4.Version:
|
||||
return parseV4(data, incoming, fp)
|
||||
case ipv6.Version:
|
||||
return parseV6(data, incoming, fp)
|
||||
}
|
||||
return ErrUnknownIPVersion
|
||||
}
|
||||
|
||||
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
dataLen := len(data)
|
||||
if dataLen < ipv6.HeaderLen {
|
||||
return ErrIPv6PacketTooShort
|
||||
}
|
||||
|
||||
if incoming {
|
||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
|
||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
|
||||
} else {
|
||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
|
||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
|
||||
}
|
||||
|
||||
protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header
|
||||
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
||||
next := 0
|
||||
for {
|
||||
if dataLen < offset {
|
||||
break
|
||||
}
|
||||
|
||||
proto := layers.IPProtocol(data[protoAt])
|
||||
//fmt.Println(proto, protoAt)
|
||||
switch proto {
|
||||
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||
fp.Protocol = uint8(proto)
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
fp.Fragment = false
|
||||
return nil
|
||||
|
||||
case layers.IPProtocolTCP, layers.IPProtocolUDP:
|
||||
if dataLen < offset+4 {
|
||||
return ErrIPv6PacketTooShort
|
||||
}
|
||||
|
||||
fp.Protocol = uint8(proto)
|
||||
if incoming {
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||
} else {
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||
}
|
||||
|
||||
fp.Fragment = false
|
||||
return nil
|
||||
|
||||
case layers.IPProtocolIPv6Fragment:
|
||||
// Fragment header is 8 bytes, need at least offset+4 to read the offset field
|
||||
if dataLen < offset+8 {
|
||||
return ErrIPv6PacketTooShort
|
||||
}
|
||||
|
||||
// Check if this is the first fragment
|
||||
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
|
||||
if fragmentOffset != 0 {
|
||||
// Non-first fragment, use what we have now and stop processing
|
||||
fp.Protocol = data[offset]
|
||||
fp.Fragment = true
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// The next loop should be the transport layer since we are the first fragment
|
||||
next = 8 // Fragment headers are always 8 bytes
|
||||
|
||||
case layers.IPProtocolAH:
|
||||
// Auth headers, used by IPSec, have a different meaning for header length
|
||||
if dataLen < offset+1 {
|
||||
break
|
||||
}
|
||||
|
||||
next = int(data[offset+1]+2) << 2
|
||||
|
||||
default:
|
||||
// Normal ipv6 header length processing
|
||||
if dataLen < offset+1 {
|
||||
break
|
||||
}
|
||||
|
||||
next = int(data[offset+1]+1) << 3
|
||||
}
|
||||
|
||||
if next <= 0 {
|
||||
// Safety check, each ipv6 header has to be at least 8 bytes
|
||||
next = 8
|
||||
}
|
||||
|
||||
protoAt = offset
|
||||
offset = offset + next
|
||||
}
|
||||
|
||||
return ErrIPv6CouldNotFindPayload
|
||||
}
|
||||
|
||||
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
// Do we at least have an ipv4 header worth of data?
|
||||
if len(data) < ipv4.HeaderLen {
|
||||
return ErrIPv4PacketTooShort
|
||||
}
|
||||
|
||||
// Adjust our start position based on the advertised ip header length
|
||||
ihl := int(data[0]&0x0f) << 2
|
||||
|
||||
// Well formed ip header length?
|
||||
// Well-formed ip header length?
|
||||
if ihl < ipv4.HeaderLen {
|
||||
return fmt.Errorf("packet had an invalid header length: %v", ihl)
|
||||
return ErrIPv4InvalidHeaderLength
|
||||
}
|
||||
|
||||
// Check if this is the second or further fragment of a fragmented packet.
|
||||
@ -334,14 +420,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
minLen += minFwPacketLen
|
||||
}
|
||||
if len(data) < minLen {
|
||||
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
|
||||
return ErrIPv4InvalidHeaderLength
|
||||
}
|
||||
|
||||
// Firewall packets are locally oriented
|
||||
if incoming {
|
||||
//TODO: IPV6-WORK
|
||||
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
|
||||
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
|
||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
@ -350,9 +435,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
||||
}
|
||||
} else {
|
||||
//TODO: IPV6-WORK
|
||||
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
|
||||
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
|
||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
@ -387,8 +471,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
|
||||
return false
|
||||
}
|
||||
|
||||
@ -435,9 +517,8 @@ func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
|
||||
func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
|
||||
f.messageMetrics.Tx(header.RecvError, 0, 1)
|
||||
|
||||
//TODO: this should be a signed message so we can trust that we should drop the index
|
||||
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
||||
f.outside.WriteTo(b, endpoint)
|
||||
_ = f.outside.WriteTo(b, endpoint)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("index", index).
|
||||
WithField("udpAddr", endpoint).
|
||||
@ -471,65 +552,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||
// We also delete it from pending hostmap to allow for fast reconnect.
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
}
|
||||
|
||||
/*
|
||||
func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) {
|
||||
if ci.eKey != nil {
|
||||
//TODO: log error?
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := proto.Marshal(meta)
|
||||
if err != nil {
|
||||
l.Debugln("failed to encode header")
|
||||
}
|
||||
|
||||
c := ci.messageCounter
|
||||
b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c)
|
||||
ci.messageCounter++
|
||||
|
||||
msg := ci.eKey.EncryptDanger(b, nil, msg, c)
|
||||
//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
|
||||
f.outside.WriteTo(msg, endpoint)
|
||||
}
|
||||
*/
|
||||
|
||||
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
|
||||
pk := h.PeerStatic()
|
||||
|
||||
if pk == nil {
|
||||
return nil, errors.New("no peer static key was present")
|
||||
}
|
||||
|
||||
if rawCertBytes == nil {
|
||||
return nil, errors.New("provided payload was empty")
|
||||
}
|
||||
|
||||
r := &cert.RawNebulaCertificate{}
|
||||
err := proto.Unmarshal(rawCertBytes, r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling cert: %s", err)
|
||||
}
|
||||
|
||||
// If the Details are nil, just exit to avoid crashing
|
||||
if r.Details == nil {
|
||||
return nil, fmt.Errorf("certificate did not contain any details")
|
||||
}
|
||||
|
||||
r.Details.PublicKey = pk
|
||||
recombined, err := proto.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while recombining certificate: %s", err)
|
||||
}
|
||||
|
||||
c, _ := cert.UnmarshalNebulaCertificate(recombined)
|
||||
isValid, err := c.Verify(time.Now(), caPool)
|
||||
if err != nil {
|
||||
return c, fmt.Errorf("certificate validation failed: %s", err)
|
||||
} else if !isValid {
|
||||
// This case should never happen but here's to defensive programming!
|
||||
return c, errors.New("certificate validation failed but did not return an error")
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
547
outside_test.go
547
outside_test.go
@ -1,21 +1,33 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
func Test_newPacket(t *testing.T) {
|
||||
p := &firewall.Packet{}
|
||||
|
||||
// length fail
|
||||
err := newPacket([]byte{0, 1}, true, p)
|
||||
assert.EqualError(t, err, "packet is less than 20 bytes")
|
||||
// length fails
|
||||
err := newPacket([]byte{}, true, p)
|
||||
require.ErrorIs(t, err, ErrPacketTooShort)
|
||||
|
||||
err = newPacket([]byte{0x40}, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
||||
|
||||
err = newPacket([]byte{0x60}, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
|
||||
// length fail with ip options
|
||||
h := ipv4.Header{
|
||||
@ -28,16 +40,15 @@ func Test_newPacket(t *testing.T) {
|
||||
|
||||
b, _ := h.Marshal()
|
||||
err = newPacket(b, true, p)
|
||||
|
||||
assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
|
||||
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||
|
||||
// not an ipv4 packet
|
||||
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||
assert.EqualError(t, err, "packet is not ipv4, type: 0")
|
||||
require.ErrorIs(t, err, ErrUnknownIPVersion)
|
||||
|
||||
// invalid ihl
|
||||
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||
assert.EqualError(t, err, "packet had an invalid header length: 8")
|
||||
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||
|
||||
// account for variable ip header length - incoming
|
||||
h = ipv4.Header{
|
||||
@ -53,12 +64,13 @@ func Test_newPacket(t *testing.T) {
|
||||
b = append(b, []byte{0, 3, 0, 4}...)
|
||||
err = newPacket(b, true, p)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
|
||||
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
|
||||
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
|
||||
assert.Equal(t, p.RemotePort, uint16(3))
|
||||
assert.Equal(t, p.LocalPort, uint16(4))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
|
||||
assert.Equal(t, uint16(3), p.RemotePort)
|
||||
assert.Equal(t, uint16(4), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// account for variable ip header length - outgoing
|
||||
h = ipv4.Header{
|
||||
@ -74,10 +86,507 @@ func Test_newPacket(t *testing.T) {
|
||||
b = append(b, []byte{0, 5, 0, 6}...)
|
||||
err = newPacket(b, false, p)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, p.Protocol, uint8(2))
|
||||
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
|
||||
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
|
||||
assert.Equal(t, p.RemotePort, uint16(6))
|
||||
assert.Equal(t, p.LocalPort, uint16(5))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(2), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
|
||||
assert.Equal(t, uint16(6), p.RemotePort)
|
||||
assert.Equal(t, uint16(5), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
}
|
||||
|
||||
func Test_newPacket_v6(t *testing.T) {
|
||||
p := &firewall.Packet{}
|
||||
|
||||
// invalid ipv6
|
||||
ip := layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 128,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
buffer := gopacket.NewSerializeBuffer()
|
||||
opt := gopacket.SerializeOptions{
|
||||
ComputeChecksums: false,
|
||||
FixLengths: false,
|
||||
}
|
||||
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = newPacket(buffer.Bytes(), true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A good ICMP packet
|
||||
ip = layers.IPv6{
|
||||
Version: 6,
|
||||
NextHeader: layers.IPProtocolICMPv6,
|
||||
HopLimit: 128,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
icmp := layers.ICMPv6{}
|
||||
|
||||
buffer.Clear()
|
||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = newPacket(buffer.Bytes(), true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
assert.Equal(t, uint16(0), p.RemotePort)
|
||||
assert.Equal(t, uint16(0), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// A good ESP packet
|
||||
b := buffer.Bytes()
|
||||
b[6] = byte(layers.IPProtocolESP)
|
||||
err = newPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
assert.Equal(t, uint16(0), p.RemotePort)
|
||||
assert.Equal(t, uint16(0), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// A good None packet
|
||||
b = buffer.Bytes()
|
||||
b[6] = byte(layers.IPProtocolNoNextHeader)
|
||||
err = newPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
assert.Equal(t, uint16(0), p.RemotePort)
|
||||
assert.Equal(t, uint16(0), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// An unknown protocol packet
|
||||
b = buffer.Bytes()
|
||||
b[6] = 255 // 255 is a reserved protocol number
|
||||
err = newPacket(b, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A good UDP packet
|
||||
ip = layers.IPv6{
|
||||
Version: 6,
|
||||
NextHeader: firewall.ProtoUDP,
|
||||
HopLimit: 128,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
udp := layers.UDP{
|
||||
SrcPort: layers.UDPPort(36123),
|
||||
DstPort: layers.UDPPort(22),
|
||||
}
|
||||
err = udp.SetNetworkLayerForChecksum(&ip)
|
||||
require.NoError(t, err)
|
||||
|
||||
buffer.Clear()
|
||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b = buffer.Bytes()
|
||||
|
||||
// incoming
|
||||
err = newPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
assert.Equal(t, uint16(36123), p.RemotePort)
|
||||
assert.Equal(t, uint16(22), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// outgoing
|
||||
err = newPacket(b, false, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||
assert.Equal(t, uint16(36123), p.LocalPort)
|
||||
assert.Equal(t, uint16(22), p.RemotePort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Too short UDP packet
|
||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
|
||||
// A good TCP packet
|
||||
b[6] = byte(layers.IPProtocolTCP)
|
||||
|
||||
// incoming
|
||||
err = newPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
assert.Equal(t, uint16(36123), p.RemotePort)
|
||||
assert.Equal(t, uint16(22), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// outgoing
|
||||
err = newPacket(b, false, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||
assert.Equal(t, uint16(36123), p.LocalPort)
|
||||
assert.Equal(t, uint16(22), p.RemotePort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Too short TCP packet
|
||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
|
||||
// A good UDP packet with an AH header
|
||||
ip = layers.IPv6{
|
||||
Version: 6,
|
||||
NextHeader: layers.IPProtocolAH,
|
||||
HopLimit: 128,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
ah := layers.IPSecAH{
|
||||
AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
|
||||
}
|
||||
ah.NextHeader = layers.IPProtocolUDP
|
||||
|
||||
udpHeader := []byte{
|
||||
0x8d, 0x1b, // Source port 36123
|
||||
0x00, 0x16, // Destination port 22
|
||||
0x00, 0x00, // Length
|
||||
0x00, 0x00, // Checksum
|
||||
}
|
||||
|
||||
buffer.Clear()
|
||||
err = ip.SerializeTo(buffer, opt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
b = buffer.Bytes()
|
||||
ahb := serializeAH(&ah)
|
||||
b = append(b, ahb...)
|
||||
b = append(b, udpHeader...)
|
||||
|
||||
err = newPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
assert.Equal(t, uint16(36123), p.RemotePort)
|
||||
assert.Equal(t, uint16(22), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Invalid AH header
|
||||
b = buffer.Bytes()
|
||||
err = newPacket(b, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
}
|
||||
|
||||
func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||
p := &firewall.Packet{}
|
||||
|
||||
ip := &layers.IPv6{
|
||||
Version: 6,
|
||||
NextHeader: layers.IPProtocolIPv6Fragment,
|
||||
HopLimit: 64,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
// First fragment
|
||||
fragHeader1 := []byte{
|
||||
uint8(layers.IPProtocolUDP), // Next Header (UDP)
|
||||
0x00, // Reserved
|
||||
0x00, // Fragment Offset high byte (0)
|
||||
0x01, // Fragment Offset low byte & flags (M=1)
|
||||
0x00, 0x00, 0x00, 0x01, // Identification
|
||||
}
|
||||
|
||||
udpHeader := []byte{
|
||||
0x8d, 0x1b, // Source port 36123
|
||||
0x00, 0x16, // Destination port 22
|
||||
0x00, 0x00, // Length
|
||||
0x00, 0x00, // Checksum
|
||||
}
|
||||
|
||||
buffer := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
err := ip.SerializeTo(buffer, opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
firstFrag := buffer.Bytes()
|
||||
firstFrag = append(firstFrag, fragHeader1...)
|
||||
firstFrag = append(firstFrag, udpHeader...)
|
||||
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||
|
||||
// Test first fragment incoming
|
||||
err = newPacket(firstFrag, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||
assert.Equal(t, uint16(36123), p.RemotePort)
|
||||
assert.Equal(t, uint16(22), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Test first fragment outgoing
|
||||
err = newPacket(firstFrag, false, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||
assert.Equal(t, uint16(36123), p.LocalPort)
|
||||
assert.Equal(t, uint16(22), p.RemotePort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Second fragment
|
||||
fragHeader2 := []byte{
|
||||
uint8(layers.IPProtocolUDP), // Next Header (UDP)
|
||||
0x00, // Reserved
|
||||
0xb9, // Fragment Offset high byte (185)
|
||||
0x01, // Fragment Offset low byte & flags (M=1)
|
||||
0x00, 0x00, 0x00, 0x01, // Identification
|
||||
}
|
||||
|
||||
buffer.Clear()
|
||||
err = ip.SerializeTo(buffer, opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secondFrag := buffer.Bytes()
|
||||
secondFrag = append(secondFrag, fragHeader2...)
|
||||
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||
|
||||
// Test second fragment incoming
|
||||
err = newPacket(secondFrag, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||
assert.Equal(t, uint16(0), p.RemotePort)
|
||||
assert.Equal(t, uint16(0), p.LocalPort)
|
||||
assert.True(t, p.Fragment)
|
||||
|
||||
// Test second fragment outgoing
|
||||
err = newPacket(secondFrag, false, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||
assert.Equal(t, uint16(0), p.LocalPort)
|
||||
assert.Equal(t, uint16(0), p.RemotePort)
|
||||
assert.True(t, p.Fragment)
|
||||
|
||||
// Too short of a fragment packet
|
||||
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
}
|
||||
|
||||
func BenchmarkParseV6(b *testing.B) {
|
||||
// Regular UDP packet
|
||||
ip := &layers.IPv6{
|
||||
Version: 6,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
HopLimit: 64,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(36123),
|
||||
DstPort: layers.UDPPort(22),
|
||||
}
|
||||
|
||||
buffer := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: false,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
err := gopacket.SerializeLayers(buffer, opts, ip, udp)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
normalPacket := buffer.Bytes()
|
||||
|
||||
// First Fragment packet
|
||||
ipFrag := &layers.IPv6{
|
||||
Version: 6,
|
||||
NextHeader: layers.IPProtocolIPv6Fragment,
|
||||
HopLimit: 64,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
fragHeader := []byte{
|
||||
uint8(layers.IPProtocolUDP), // Next Header (UDP)
|
||||
0x00, // Reserved
|
||||
0x00, // Fragment Offset high byte (0)
|
||||
0x01, // Fragment Offset low byte & flags (M=1)
|
||||
0x00, 0x00, 0x00, 0x01, // Identification
|
||||
}
|
||||
|
||||
udpHeader := []byte{
|
||||
0x8d, 0x7b, // Source port 36123
|
||||
0x00, 0x16, // Destination port 22
|
||||
0x00, 0x00, // Length
|
||||
0x00, 0x00, // Checksum
|
||||
}
|
||||
|
||||
buffer.Clear()
|
||||
err = ipFrag.SerializeTo(buffer, opts)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
firstFrag := buffer.Bytes()
|
||||
firstFrag = append(firstFrag, fragHeader...)
|
||||
firstFrag = append(firstFrag, udpHeader...)
|
||||
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||
|
||||
// Second Fragment packet
|
||||
fragHeader[2] = 0xb9 // offset 185
|
||||
buffer.Clear()
|
||||
err = ipFrag.SerializeTo(buffer, opts)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
secondFrag := buffer.Bytes()
|
||||
secondFrag = append(secondFrag, fragHeader...)
|
||||
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||
|
||||
fp := &firewall.Packet{}
|
||||
|
||||
b.Run("Normal", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err = parseV6(normalPacket, true, fp); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("FirstFragment", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err = parseV6(firstFrag, true, fp); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("SecondFragment", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err = parseV6(secondFrag, true, fp); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Evil packet
|
||||
evilPacket := &layers.IPv6{
|
||||
Version: 6,
|
||||
NextHeader: layers.IPProtocolIPv6HopByHop,
|
||||
HopLimit: 64,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
hopHeader := []byte{
|
||||
uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
|
||||
0x00, // Length
|
||||
0x00, 0x00, // Options and padding
|
||||
0x00, 0x00, 0x00, 0x00, // More options and padding
|
||||
}
|
||||
|
||||
lastHopHeader := []byte{
|
||||
uint8(layers.IPProtocolUDP), // Next Header (UDP)
|
||||
0x00, // Length
|
||||
0x00, 0x00, // Options and padding
|
||||
0x00, 0x00, 0x00, 0x00, // More options and padding
|
||||
}
|
||||
|
||||
buffer.Clear()
|
||||
err = evilPacket.SerializeTo(buffer, opts)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
evilBytes := buffer.Bytes()
|
||||
for i := 0; i < 200; i++ {
|
||||
evilBytes = append(evilBytes, hopHeader...)
|
||||
}
|
||||
evilBytes = append(evilBytes, lastHopHeader...)
|
||||
evilBytes = append(evilBytes, udpHeader...)
|
||||
evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||
|
||||
b.Run("200 HopByHop headers", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err = parseV6(evilBytes, false, fp); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure authentication data is a multiple of 8 bytes by padding if necessary
|
||||
func padAuthData(authData []byte) []byte {
|
||||
// Length of Authentication Data must be a multiple of 8 bytes
|
||||
paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary
|
||||
if paddingLength > 0 {
|
||||
authData = append(authData, make([]byte, paddingLength)...)
|
||||
}
|
||||
return authData
|
||||
}
|
||||
|
||||
// Custom function to manually serialize IPSecAH for both IPv4 and IPv6
|
||||
func serializeAH(ah *layers.IPSecAH) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
// Ensure Authentication Data is a multiple of 8 bytes
|
||||
ah.AuthenticationData = padAuthData(ah.AuthenticationData)
|
||||
// Calculate Payload Length (in 32-bit words, minus 2)
|
||||
payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
|
||||
|
||||
// Serialize fields
|
||||
if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(ah.AuthenticationData) > 0 {
|
||||
if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
@ -8,7 +8,7 @@ import (
|
||||
type Device interface {
|
||||
io.ReadWriteCloser
|
||||
Activate() error
|
||||
Cidr() netip.Prefix
|
||||
Networks() []netip.Prefix
|
||||
Name() string
|
||||
RouteFor(netip.Addr) netip.Addr
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
|
||||
@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
|
||||
return routeTree, nil
|
||||
}
|
||||
|
||||
func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
||||
func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
var err error
|
||||
|
||||
r := c.Get("tun.routes")
|
||||
@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
||||
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
|
||||
}
|
||||
|
||||
if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
|
||||
found := false
|
||||
for _, network := range networks {
|
||||
if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf(
|
||||
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
|
||||
"entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
|
||||
i+1,
|
||||
r.Cidr.String(),
|
||||
network.String(),
|
||||
networks,
|
||||
)
|
||||
}
|
||||
|
||||
@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
||||
func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
var err error
|
||||
|
||||
r := c.Get("tun.unsafe_routes")
|
||||
@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
||||
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
|
||||
}
|
||||
|
||||
if network.Contains(r.Cidr.Addr()) {
|
||||
return nil, fmt.Errorf(
|
||||
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
|
||||
i+1,
|
||||
r.Cidr.String(),
|
||||
network.String(),
|
||||
)
|
||||
for _, network := range networks {
|
||||
if network.Contains(r.Cidr.Addr()) {
|
||||
return nil, fmt.Errorf(
|
||||
"entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
|
||||
i+1,
|
||||
r.Cidr.String(),
|
||||
network.String(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
routes[i] = r
|
||||
|
||||
@ -8,86 +8,93 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_parseRoutes(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// test no routes config
|
||||
routes, err := parseRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, routes, 0)
|
||||
routes, err := parseRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, routes)
|
||||
|
||||
// not an array
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "tun.routes is not an array")
|
||||
require.EqualError(t, err, "tun.routes is not an array")
|
||||
|
||||
// no routes
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, routes, 0)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, routes)
|
||||
|
||||
// weird route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
||||
require.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
||||
|
||||
// no mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
||||
|
||||
// bad mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||
|
||||
// low mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
||||
|
||||
// missing route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
||||
|
||||
// unparsable route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||
|
||||
// below network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
|
||||
|
||||
// above network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
|
||||
|
||||
// Not in multiple ranges
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
|
||||
|
||||
// happy case
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
|
||||
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
|
||||
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
|
||||
}}
|
||||
routes, err = parseRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, routes, 2)
|
||||
|
||||
tested := 0
|
||||
@ -113,106 +120,106 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// test no routes config
|
||||
routes, err := parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, routes, 0)
|
||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, routes)
|
||||
|
||||
// not an array
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
|
||||
require.EqualError(t, err, "tun.unsafe_routes is not an array")
|
||||
|
||||
// no routes
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, routes, 0)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, routes)
|
||||
|
||||
// weird route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
||||
require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
||||
|
||||
// no via
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
||||
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
||||
|
||||
// invalid via
|
||||
for _, invalidValue := range []interface{}{
|
||||
127, false, nil, 1.0, []string{"1", "2"},
|
||||
} {
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
|
||||
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
|
||||
}
|
||||
|
||||
// unparsable via
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
||||
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
||||
|
||||
// missing route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
||||
|
||||
// unparsable route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||
|
||||
// within network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
||||
|
||||
// below network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// above network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// no mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Equal(t, 0, routes[0].MTU)
|
||||
|
||||
// bad mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||
|
||||
// low mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
||||
|
||||
// bad install
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
||||
require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
||||
|
||||
// happy case
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
||||
@ -221,8 +228,8 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
||||
}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, routes, 4)
|
||||
|
||||
tested := 0
|
||||
@ -254,38 +261,38 @@ func Test_makeRouteTree(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
||||
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
||||
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
||||
}}
|
||||
routes, err := parseUnsafeRoutes(c, n)
|
||||
assert.NoError(t, err)
|
||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, routes, 2)
|
||||
routeTree, err := makeRouteTree(l, routes, true)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip, err := netip.ParseAddr("1.0.0.2")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
r, ok := routeTree.Lookup(ip)
|
||||
assert.True(t, ok)
|
||||
|
||||
nip, err := netip.ParseAddr("192.168.0.1")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nip, r)
|
||||
|
||||
ip, err = netip.ParseAddr("1.0.0.1")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
r, ok = routeTree.Lookup(ip)
|
||||
assert.True(t, ok)
|
||||
|
||||
nip, err = netip.ParseAddr("192.168.0.2")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nip, r)
|
||||
|
||||
ip, err = netip.ParseAddr("1.1.0.1")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
r, ok = routeTree.Lookup(ip)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
@ -11,36 +11,36 @@ import (
|
||||
const DefaultMTU = 1300
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
switch {
|
||||
case c.GetBool("tun.disabled", false):
|
||||
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
return tun, nil
|
||||
|
||||
default:
|
||||
return newTun(c, l, tunCidr, routines > 1)
|
||||
return newTun(c, l, vpnNetworks, routines > 1)
|
||||
}
|
||||
}
|
||||
|
||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||
return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, tunCidr)
|
||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||
}
|
||||
}
|
||||
|
||||
func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
|
||||
func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
|
||||
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
routes, err := parseRoutes(c, cidr)
|
||||
routes, err := parseRoutes(c, vpnNetworks)
|
||||
if err != nil {
|
||||
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
|
||||
unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
|
||||
unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
|
||||
if err != nil {
|
||||
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
|
||||
@ -18,14 +18,14 @@ import (
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
cidr netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
fd int
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: deviceFd,
|
||||
cidr: cidr,
|
||||
vpnNetworks: vpnNetworks,
|
||||
l: l,
|
||||
}
|
||||
|
||||
@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ func (t tun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
|
||||
@ -24,56 +24,62 @@ import (
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
Device string
|
||||
cidr netip.Prefix
|
||||
DefaultMTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
DefaultMTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
}
|
||||
|
||||
type sockaddrCtl struct {
|
||||
scLen uint8
|
||||
scFamily uint8
|
||||
ssSysaddr uint16
|
||||
scID uint32
|
||||
scUnit uint32
|
||||
scReserved [5]uint32
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
Name [16]byte
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Flags uint16
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
var sockaddrCtlSize uintptr = 32
|
||||
|
||||
const (
|
||||
_SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */
|
||||
_AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */
|
||||
_PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM
|
||||
_CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info)
|
||||
utunControlName = "com.apple.net.utun_control"
|
||||
_SIOCAIFADDR_IN6 = 2155899162
|
||||
_UTUN_OPT_IFNAME = 2
|
||||
_IN6_IFF_NODAD = 0x0020
|
||||
_IN6_IFF_SECURED = 0x0400
|
||||
utunControlName = "com.apple.net.utun_control"
|
||||
)
|
||||
|
||||
type ifreqAddr struct {
|
||||
Name [16]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
type ifreqMTU struct {
|
||||
Name [16]byte
|
||||
MTU int32
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
||||
type addrLifetime struct {
|
||||
Expire float64
|
||||
Preferred float64
|
||||
Vltime uint32
|
||||
Pltime uint32
|
||||
}
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
name := c.GetString("tun.dev", "")
|
||||
ifIndex := -1
|
||||
if name != "" && name != "utun" {
|
||||
@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
||||
}
|
||||
}
|
||||
|
||||
fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL)
|
||||
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("system socket: %v", err)
|
||||
}
|
||||
|
||||
var ctlInfo = &struct {
|
||||
ctlID uint32
|
||||
ctlName [96]byte
|
||||
}{}
|
||||
var ctlInfo = &unix.CtlInfo{}
|
||||
copy(ctlInfo.Name[:], utunControlName)
|
||||
|
||||
copy(ctlInfo.ctlName[:], utunControlName)
|
||||
|
||||
err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
|
||||
err = unix.IoctlCtlInfo(fd, ctlInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
|
||||
}
|
||||
|
||||
sc := sockaddrCtl{
|
||||
scLen: uint8(sockaddrCtlSize),
|
||||
scFamily: unix.AF_SYSTEM,
|
||||
ssSysaddr: _AF_SYS_CONTROL,
|
||||
scID: ctlInfo.ctlID,
|
||||
scUnit: uint32(ifIndex) + 1,
|
||||
err = unix.Connect(fd, &unix.SockaddrCtl{
|
||||
ID: ctlInfo.Id,
|
||||
Unit: uint32(ifIndex) + 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
|
||||
}
|
||||
|
||||
_, _, errno := unix.RawSyscall(
|
||||
unix.SYS_CONNECT,
|
||||
uintptr(fd),
|
||||
uintptr(unsafe.Pointer(&sc)),
|
||||
sockaddrCtlSize,
|
||||
)
|
||||
if errno != 0 {
|
||||
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
|
||||
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
|
||||
}
|
||||
|
||||
var ifName struct {
|
||||
name [16]byte
|
||||
}
|
||||
ifNameSize := uintptr(len(ifName.name))
|
||||
_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
|
||||
2, // SYSPROTO_CONTROL
|
||||
2, // UTUN_OPT_IFNAME
|
||||
uintptr(unsafe.Pointer(&ifName)),
|
||||
uintptr(unsafe.Pointer(&ifNameSize)), 0)
|
||||
if errno != 0 {
|
||||
return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
|
||||
}
|
||||
name = string(ifName.name[:ifNameSize-1])
|
||||
|
||||
err = syscall.SetNonblock(fd, true)
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SetNonblock: %v", err)
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
||||
Device: name,
|
||||
cidr: cidr,
|
||||
vpnNetworks: vpnNetworks,
|
||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
@ -186,16 +167,6 @@ func (t *tun) Close() error {
|
||||
func (t *tun) Activate() error {
|
||||
devName := t.deviceBytes()
|
||||
|
||||
var addr, mask [4]byte
|
||||
|
||||
if !t.cidr.Addr().Is4() {
|
||||
//TODO: IPV6-WORK
|
||||
panic("need ipv6")
|
||||
}
|
||||
|
||||
addr = t.cidr.Addr().As4()
|
||||
copy(mask[:], prefixToMask(t.cidr))
|
||||
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM,
|
||||
@ -208,66 +179,18 @@ func (t *tun) Activate() error {
|
||||
|
||||
fd := uintptr(s)
|
||||
|
||||
ifra := ifreqAddr{
|
||||
Name: devName,
|
||||
Addr: unix.RawSockaddrInet4{
|
||||
Family: unix.AF_INET,
|
||||
Addr: addr,
|
||||
},
|
||||
}
|
||||
|
||||
// Set the device ip address
|
||||
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address: %s", err)
|
||||
}
|
||||
|
||||
// Set the device network
|
||||
ifra.Addr.Addr = mask
|
||||
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
||||
return fmt.Errorf("failed to set tun netmask: %s", err)
|
||||
}
|
||||
|
||||
// Set the device name
|
||||
ifrf := ifReq{Name: devName}
|
||||
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to set tun device name: %s", err)
|
||||
}
|
||||
|
||||
// Set the MTU on the device
|
||||
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
|
||||
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %v", err)
|
||||
}
|
||||
|
||||
/*
|
||||
// Set the transmit queue length
|
||||
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
||||
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||
l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
}
|
||||
*/
|
||||
|
||||
// Bring up the interface
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
||||
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
||||
// Get the device flags
|
||||
ifrf := ifReq{Name: devName}
|
||||
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to get tun flags: %s", err)
|
||||
}
|
||||
|
||||
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
unix.Shutdown(routeSock, unix.SHUT_RDWR)
|
||||
err := unix.Close(routeSock)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
|
||||
}
|
||||
}()
|
||||
|
||||
routeAddr := &netroute.Inet4Addr{}
|
||||
maskAddr := &netroute.Inet4Addr{}
|
||||
linkAddr, err := getLinkAddr(t.Device)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -277,14 +200,18 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
t.linkAddr = linkAddr
|
||||
|
||||
copy(routeAddr.IP[:], addr[:])
|
||||
copy(maskAddr.IP[:], mask[:])
|
||||
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
|
||||
for _, network := range t.vpnNetworks {
|
||||
if network.Addr().Is4() {
|
||||
err = t.activate4(network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = t.activate6(network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Run the interface
|
||||
@ -297,8 +224,89 @@ func (t *tun) Activate() error {
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) activate4(network netip.Prefix) error {
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM,
|
||||
unix.IPPROTO_IP,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(s)
|
||||
|
||||
ifr := ifreqAlias4{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: network.Addr().As4(),
|
||||
},
|
||||
DstAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: network.Addr().As4(),
|
||||
},
|
||||
MaskAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(network).As4(),
|
||||
},
|
||||
}
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun v4 address: %s", err)
|
||||
}
|
||||
|
||||
err = addRoute(network, t.linkAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) activate6(network netip.Prefix) error {
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET6,
|
||||
unix.SOCK_DGRAM,
|
||||
unix.IPPROTO_IP,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(s)
|
||||
|
||||
ifr := ifreqAlias6{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: network.Addr().As16(),
|
||||
},
|
||||
PrefixMask: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(network).As16(),
|
||||
},
|
||||
Lifetime: addrLifetime{
|
||||
// never expires
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
|
||||
Flags: _IN6_IFF_NODAD,
|
||||
}
|
||||
|
||||
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -343,7 +351,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
}
|
||||
|
||||
// Get the LinkAddr for the interface of the given name
|
||||
// TODO: Is there an easier way to fetch this when we create the interface?
|
||||
// Is there an easier way to fetch this when we create the interface?
|
||||
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
|
||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||
@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
unix.Shutdown(routeSock, unix.SHUT_RDWR)
|
||||
err := unix.Close(routeSock)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
|
||||
}
|
||||
}()
|
||||
|
||||
routeAddr := &netroute.Inet4Addr{}
|
||||
maskAddr := &netroute.Inet4Addr{}
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
if !r.Cidr.Addr().Is4() {
|
||||
//TODO: implement ipv6
|
||||
panic("Cant handle ipv6 routes yet")
|
||||
}
|
||||
|
||||
routeAddr.IP = r.Cidr.Addr().As4()
|
||||
//TODO: we could avoid the copy
|
||||
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
|
||||
|
||||
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
|
||||
err := addRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
t.l.WithField("route", r.Cidr).
|
||||
@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
unix.Shutdown(routeSock, unix.SHUT_RDWR)
|
||||
err := unix.Close(routeSock)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
|
||||
}
|
||||
}()
|
||||
|
||||
routeAddr := &netroute.Inet4Addr{}
|
||||
maskAddr := &netroute.Inet4Addr{}
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
if r.Cidr.Addr().Is6() {
|
||||
//TODO: implement ipv6
|
||||
panic("Cant handle ipv6 routes yet")
|
||||
}
|
||||
|
||||
routeAddr.IP = r.Cidr.Addr().As4()
|
||||
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
|
||||
|
||||
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
|
||||
err := delRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
|
||||
r := netroute.RouteMessage{
|
||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP,
|
||||
Seq: 1,
|
||||
Addrs: []netroute.Addr{
|
||||
unix.RTAX_DST: addr,
|
||||
unix.RTAX_GATEWAY: link,
|
||||
unix.RTAX_NETMASK: mask,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := r.Marshal()
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
|
||||
r := netroute.RouteMessage{
|
||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
Addrs: []netroute.Addr{
|
||||
unix.RTAX_DST: addr,
|
||||
unix.RTAX_GATEWAY: link,
|
||||
unix.RTAX_NETMASK: mask,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := r.Marshal()
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.ReadWriteCloser.Read(buf)
|
||||
@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) []byte {
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
return net.CIDRMask(prefix.Bits(), pLen)
|
||||
|
||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||
return addr
|
||||
}
|
||||
|
||||
@ -12,8 +12,8 @@ import (
|
||||
)
|
||||
|
||||
type disabledTun struct {
|
||||
read chan []byte
|
||||
cidr netip.Prefix
|
||||
read chan []byte
|
||||
vpnNetworks []netip.Prefix
|
||||
|
||||
// Track these metrics since we don't have the tun device to do it for us
|
||||
tx metrics.Counter
|
||||
@ -21,11 +21,11 @@ type disabledTun struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||
tun := &disabledTun{
|
||||
cidr: cidr,
|
||||
read: make(chan []byte, queueLen),
|
||||
l: l,
|
||||
vpnNetworks: vpnNetworks,
|
||||
read: make(chan []byte, queueLen),
|
||||
l: l,
|
||||
}
|
||||
|
||||
if metricsEnabled {
|
||||
@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (t *disabledTun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
func (t *disabledTun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (*disabledTun) Name() string {
|
||||
|
||||
@ -46,12 +46,12 @@ type ifreqDestroy struct {
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
cidr netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
@ -78,11 +78,11 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
@ -195,8 +195,18 @@ func (t *tun) Activate() error {
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
|
||||
@ -21,20 +21,20 @@ import (
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
cidr netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
cidr: cidr,
|
||||
vpnNetworks: vpnNetworks,
|
||||
ReadWriteCloser: &tunReadCloser{f: file},
|
||||
l: l,
|
||||
}
|
||||
@ -59,7 +59,7 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
|
||||
return tr.f.Close()
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
@ -25,7 +26,7 @@ type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
cidr netip.Prefix
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
@ -40,18 +41,16 @@ type tun struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
Name [16]byte
|
||||
Flags uint16
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
type ifreqAddr struct {
|
||||
Name [16]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
type ifreqMTU struct {
|
||||
Name [16]byte
|
||||
MTU int32
|
||||
@ -64,10 +63,10 @@ type ifreqQLEN struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
t, err := newTunGeneric(c, l, file, cidr)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
t, err := newTunGeneric(c, l, file, cidr)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
cidr: cidr,
|
||||
vpnNetworks: vpnNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
l: l,
|
||||
@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -190,11 +189,13 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
}
|
||||
|
||||
if oldDefaultMTU != newDefaultMTU {
|
||||
err := t.setDefaultRoute()
|
||||
if err != nil {
|
||||
t.l.Warn(err)
|
||||
} else {
|
||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.setDefaultRoute(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
t.l.Warn(err)
|
||||
} else {
|
||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
var nn int
|
||||
max := len(b)
|
||||
maximum := len(b)
|
||||
|
||||
for {
|
||||
n, err := unix.Write(t.fd, b[nn:max])
|
||||
n, err := unix.Write(t.fd, b[nn:maximum])
|
||||
if n > 0 {
|
||||
nn += n
|
||||
}
|
||||
@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
|
||||
for i := range al {
|
||||
if al[i].Equal(x) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
|
||||
func (t *tun) addIPs(link netlink.Link) error {
|
||||
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
|
||||
for i := range t.vpnNetworks {
|
||||
newAddrs[i] = &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: t.vpnNetworks[i].Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
|
||||
},
|
||||
Label: t.vpnNetworks[i].Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
//add all new addresses
|
||||
for i := range newAddrs {
|
||||
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
|
||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
//iterate over remainder, remove whoever shouldn't be there
|
||||
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tun address list: %s", err)
|
||||
}
|
||||
|
||||
for i := range al {
|
||||
if hasNetlinkAddr(newAddrs, al[i]) {
|
||||
continue
|
||||
}
|
||||
err = netlink.AddrDel(link, &al[i])
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to remove address from tun address list")
|
||||
} else {
|
||||
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
devName := t.deviceBytes()
|
||||
|
||||
@ -272,15 +325,8 @@ func (t *tun) Activate() error {
|
||||
t.watchRoutes()
|
||||
}
|
||||
|
||||
var addr, mask [4]byte
|
||||
|
||||
//TODO: IPV6-WORK
|
||||
addr = t.cidr.Addr().As4()
|
||||
tmask := net.CIDRMask(t.cidr.Bits(), 32)
|
||||
copy(mask[:], tmask)
|
||||
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
|
||||
unix.SOCK_DGRAM,
|
||||
unix.IPPROTO_IP,
|
||||
)
|
||||
@ -289,31 +335,19 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
t.ioctlFd = uintptr(s)
|
||||
|
||||
ifra := ifreqAddr{
|
||||
Name: devName,
|
||||
Addr: unix.RawSockaddrInet4{
|
||||
Family: unix.AF_INET,
|
||||
Addr: addr,
|
||||
},
|
||||
}
|
||||
|
||||
// Set the device ip address
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address: %s", err)
|
||||
}
|
||||
|
||||
// Set the device network
|
||||
ifra.Addr.Addr = mask
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
||||
return fmt.Errorf("failed to set tun netmask: %s", err)
|
||||
}
|
||||
|
||||
// Set the device name
|
||||
ifrf := ifReq{Name: devName}
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to set tun device name: %s", err)
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(t.Device)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tun device link: %s", err)
|
||||
}
|
||||
|
||||
t.deviceIndex = link.Attrs().Index
|
||||
|
||||
// Setup our default MTU
|
||||
t.setMTU()
|
||||
|
||||
@ -324,20 +358,21 @@ func (t *tun) Activate() error {
|
||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
}
|
||||
|
||||
if err = t.addIPs(link); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(t.Device)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tun device link: %s", err)
|
||||
}
|
||||
t.deviceIndex = link.Attrs().Index
|
||||
|
||||
if err = t.setDefaultRoute(); err != nil {
|
||||
return err
|
||||
//set route MTU
|
||||
for i := range t.vpnNetworks {
|
||||
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
|
||||
return fmt.Errorf("failed to set default route MTU: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the routes
|
||||
@ -363,12 +398,10 @@ func (t *tun) setMTU() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) setDefaultRoute() error {
|
||||
// Default route
|
||||
|
||||
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||
dr := &net.IPNet{
|
||||
IP: t.cidr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
|
||||
IP: cidr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
@ -377,14 +410,27 @@ func (t *tun) setDefaultRoute() error {
|
||||
MTU: t.DefaultMTU,
|
||||
AdvMSS: t.advMSS(Route{}),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Src: net.IP(t.cidr.Addr().AsSlice()),
|
||||
Src: net.IP(cidr.Addr().AsSlice()),
|
||||
Protocol: unix.RTPROT_KERNEL,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}
|
||||
err := netlink.RouteReplace(&nr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
|
||||
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
||||
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
||||
for i := 0; i < 2; i++ {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = netlink.RouteReplace(&nr)
|
||||
if err == nil {
|
||||
break
|
||||
} else {
|
||||
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
return
|
||||
}
|
||||
|
||||
//TODO: IPV6-WORK what if not ok?
|
||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||
@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
}
|
||||
|
||||
gwAddr = gwAddr.Unmap()
|
||||
if !t.cidr.Contains(gwAddr) {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
return
|
||||
withinNetworks := false
|
||||
for i := range t.vpnNetworks {
|
||||
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||
withinNetworks = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if x := r.Dst.IP.To4(); x == nil {
|
||||
// Nebula only handles ipv4 on the overlay currently
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
|
||||
if !withinNetworks {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
|
||||
return
|
||||
}
|
||||
|
||||
@ -563,11 +605,11 @@ func (t *tun) Close() error {
|
||||
}
|
||||
|
||||
if t.ReadWriteCloser != nil {
|
||||
t.ReadWriteCloser.Close()
|
||||
_ = t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
if t.ioctlFd > 0 {
|
||||
os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
||||
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@ -27,12 +27,12 @@ type ifreqDestroy struct {
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
cidr netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
@ -58,13 +58,13 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
var err error
|
||||
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
@ -130,8 +130,18 @@ func (t *tun) Activate() error {
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() netip.Prefix {
|
||||
return t.cidr
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user