diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index e0d41ae..288f32c 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.24' check-latest: true - name: Install goimports diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c8cf3f8..3107b47 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.24' check-latest: true - name: Build @@ -37,7 +37,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.24' check-latest: true - name: Build @@ -64,18 +64,18 @@ jobs: name: Build Universal Darwin env: HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} - runs-on: macos-11 + runs-on: macos-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.24' check-latest: true - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v2 + uses: Apple-Actions/import-codesign-certs@v5 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index 2b5e6e9..de582de 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -27,6 +27,9 @@ jobs: go-version-file: 'go.mod' check-latest: true + - name: add hashicorp source + run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list + - name: install vagrant run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index f7a73d3..2560085 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.24' check-latest: true - name: build diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index c546653..dcd132b 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -5,6 +5,10 @@ set -e -x rm -rf ./build mkdir ./build +# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1 +# - We could make this better by launching the lighthouse first and then fetching what IP it is. +NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)" + ( cd build @@ -21,16 +25,16 @@ mkdir ./build ../genconfig.sh >lighthouse1.yml HOST="host2" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ ../genconfig.sh >host2.yml HOST="host3" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host3.yml HOST="host4" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host4.yml diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh index 76cf72f..1c1e3c5 100755 --- a/.github/workflows/smoke/smoke-vagrant.sh +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test docker run --name host2 --rm "$CONTAINER" -config host2.yml -test vagrant up -vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" +vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" & +vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 15 # grab tcpdump pcaps for debugging @@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap & # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap & -docker exec host2 ncat -nklv 0.0.0.0 2000 & -vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & +#docker exec host2 ncat -nklv 0.0.0.0 2000 & +#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" & @@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1 # Should fail because not allowed by host3 inbound firewall ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 -set +x -echo -echo " *** Testing ncat from host2" -echo -set -x +#set +x +#echo +#echo " *** Testing ncat from host2" +#echo +#set -x # Should fail because not allowed by host3 inbound firewall #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 #! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 @@ -82,18 +82,18 @@ echo echo " *** Testing ping from host3" echo set -x -vagrant ssh -c "ping -c1 192.168.100.1" -vagrant ssh -c "ping -c1 192.168.100.2" +vagrant ssh -c "ping -c1 192.168.100.1" -- -T +vagrant ssh -c "ping -c1 192.168.100.2" -- -T -set +x -echo -echo " *** Testing ncat from host3" -echo -set -x +#set +x +#echo +#echo " *** Testing ncat from host3" +#echo +#set -x #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000" #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2 -vagrant ssh -c "sudo xargs kill math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) } return &calculatedRemote{ - ipNet: *ipNet, - maskIP: iputil.Ip2VpnIp(ipNet.IP), - mask: iputil.Ip2VpnIp(ipNet.Mask), - port: uint32(port), + ipNet: maskCidr, + mask: masked, + port: uint32(port), }, nil } @@ -43,21 +42,47 @@ func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { - // Combine the masked bytes of the "mask" IP with the unmasked bytes - // of the overlay IP - masked := (c.maskIP & c.mask) | (ip & ^c.mask) +func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort { + // Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP + maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + mask := binary.BigEndian.Uint32(maskb[:]) - return &Ip4AndPort{Ip: uint32(masked), Port: c.port} + b := c.mask.Addr().As4() + maskAddr := binary.BigEndian.Uint32(b[:]) + + b = addr.As4() + intAddr := binary.BigEndian.Uint32(b[:]) + + return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port} } -func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { +func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort { + mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + maskAddr := c.mask.Addr().As16() + calcAddr := addr.As16() + + ap := V6AddrPort{Port: c.port} + + maskb := binary.BigEndian.Uint64(mask[:8]) + maskAddrb := binary.BigEndian.Uint64(maskAddr[:8]) + calcAddrb := binary.BigEndian.Uint64(calcAddr[:8]) + ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb) + + maskb = binary.BigEndian.Uint64(mask[8:]) + maskAddrb = binary.BigEndian.Uint64(maskAddr[8:]) + calcAddrb = binary.BigEndian.Uint64(calcAddr[8:]) + ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb) + + return &ap +} + +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { value := c.Get(k) if value == nil { return nil, nil } - calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() + calculatedRemotes := new(bart.Table[[]*calculatedRemote]) rawMap, ok := value.(map[any]any) if !ok { @@ -69,23 +94,23 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + cidr, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } - entry, err := newCalculatedRemotesListFromConfig(rawValue) + entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } - calculatedRemotes.AddCIDR(ipNet, entry) + calculatedRemotes.Insert(cidr, entry) } return calculatedRemotes, nil } -func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { +func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) { rawList, ok := raw.([]any) if !ok { return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) @@ -93,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { var l []*calculatedRemote for _, e := range rawList { - c, err := newCalculatedRemotesEntryFromConfig(e) + c, err := newCalculatedRemotesEntryFromConfig(cidr, e) if err != nil { return nil, fmt.Errorf("calculated_remotes entry: %w", err) } @@ -103,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { return l, nil } -func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { +func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) { rawMap, ok := raw.(map[any]any) if !ok { return nil, fmt.Errorf("invalid type: %T", raw) @@ -117,7 +142,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { if !ok { return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawMask) + maskCidr, err := netip.ParsePrefix(rawMask) if err != nil { return nil, fmt.Errorf("invalid mask: %s", rawMask) } @@ -139,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) } - return newCalculatedRemote(ipNet, port) + return newCalculatedRemote(cidr, maskCidr, port) } diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 2ddebca..6df893c 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -1,27 +1,81 @@ package nebula import ( - "net" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCalculatedRemoteApply(t *testing.T) { - _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + // Test v4 addresses + ipNet := netip.MustParsePrefix("192.168.1.0/24") + c, err := newCalculatedRemote(ipNet, ipNet, 4242) require.NoError(t, err) - c, err := newCalculatedRemote(ipNet, 4242) + input, err := netip.ParseAddr("10.0.10.182") require.NoError(t, err) - input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) + expected, err := netip.ParseAddr("192.168.1.182") + require.NoError(t, err) - expected := &Ip4AndPort{ - Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), - Port: 4242, - } + assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) - assert.Equal(t, expected, c.Apply(input)) + // Test v6 addresses + ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + require.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") + require.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) + + // Test v6 addresses part 2 + ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + require.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") + require.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) + + // Test v6 addresses part 2 + ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + require.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") + require.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) +} + +func Test_newCalculatedRemote(t *testing.T) { + c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242) + require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128") + require.Nil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242) + require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32") + require.Nil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242) + require.NoError(t, err) + require.NotNil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242) + require.NoError(t, err) + require.NotNil(t, c) } diff --git a/cert/Makefile b/cert/Makefile index 28170b6..311afc2 100644 --- a/cert/Makefile +++ b/cert/Makefile @@ -1,7 +1,7 @@ GO111MODULE = on export GO111MODULE -cert.pb.go: cert.proto .FORCE +cert_v1.pb.go: cert_v1.proto .FORCE go build google.golang.org/protobuf/cmd/protoc-gen-go PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $< rm protoc-gen-go diff --git a/cert/README.md b/cert/README.md index ae19a28..1e27a6b 100644 --- a/cert/README.md +++ b/cert/README.md @@ -2,14 +2,25 @@ This is a library for interacting with `nebula` style certificates and authorities. -A `protobuf` definition of the certificate format is also included +There are now 2 versions of `nebula` certificates: -### Compiling the protobuf definition +## v1 -Make sure you have `protoc` installed. +This version is deprecated. + +A `protobuf` definition of the certificate format is included at `cert_v1.proto` + +To compile the definition you will need `protoc` installed. To compile for `go` with the same version of protobuf specified in go.mod: ```bash -make +make proto ``` + +## v2 + +This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate +future certificate changes better than v1. + +`cert_v2.asn1` defines the wire format and can be used to compile marshalers. \ No newline at end of file diff --git a/cert/asn1.go b/cert/asn1.go new file mode 100644 index 0000000..6bf6a8d --- /dev/null +++ b/cert/asn1.go @@ -0,0 +1,52 @@ +package cert + +import ( + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value +// https://github.com/golang/go/issues/64811#issuecomment-1944446920 +func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] > 0 + return true + } + + return false +} + +// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value +// Similar issue as with readOptionalASN1Boolean +func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] + return true + } + + return false +} diff --git a/cert/ca.go b/cert/ca.go deleted file mode 100644 index 0ffbd87..0000000 --- a/cert/ca.go +++ /dev/null @@ -1,140 +0,0 @@ -package cert - -import ( - "errors" - "fmt" - "strings" - "time" -) - -type NebulaCAPool struct { - CAs map[string]*NebulaCertificate - certBlocklist map[string]struct{} -} - -// NewCAPool creates a CAPool -func NewCAPool() *NebulaCAPool { - ca := NebulaCAPool{ - CAs: make(map[string]*NebulaCertificate), - certBlocklist: make(map[string]struct{}), - } - - return &ca -} - -// NewCAPoolFromBytes will create a new CA pool from the provided -// input bytes, which must be a PEM-encoded set of nebula certificates. -// If the pool contains any expired certificates, an ErrExpired will be -// returned along with the pool. The caller must handle any such errors. -func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) { - pool := NewCAPool() - var err error - var expired bool - for { - caPEMs, err = pool.AddCACertificate(caPEMs) - if errors.Is(err, ErrExpired) { - expired = true - err = nil - } - if err != nil { - return nil, err - } - if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { - break - } - } - - if expired { - return pool, ErrExpired - } - - return pool, nil -} - -// AddCACertificate verifies a Nebula CA certificate and adds it to the pool -// Only the first pem encoded object will be consumed, any remaining bytes are returned. -// Parsed certificates will be verified and must be a CA -func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) { - c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes) - if err != nil { - return pemBytes, err - } - - if !c.Details.IsCA { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA) - } - - if !c.CheckSignature(c.Details.PublicKey) { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned) - } - - sum, err := c.Sha256Sum() - if err != nil { - return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name) - } - - ncp.CAs[sum] = c - if c.Expired(time.Now()) { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired) - } - - return pemBytes, nil -} - -// BlocklistFingerprint adds a cert fingerprint to the blocklist -func (ncp *NebulaCAPool) BlocklistFingerprint(f string) { - ncp.certBlocklist[f] = struct{}{} -} - -// ResetCertBlocklist removes all previously blocklisted cert fingerprints -func (ncp *NebulaCAPool) ResetCertBlocklist() { - ncp.certBlocklist = make(map[string]struct{}) -} - -// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated -// automatically if you manually change any fields in the NebulaCertificate. -func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool { - return ncp.isBlocklistedWithCache(c, false) -} - -// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted -func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool { - h, err := c.sha256SumWithCache(useCache) - if err != nil { - return true - } - - if _, ok := ncp.certBlocklist[h]; ok { - return true - } - - return false -} - -// GetCAForCert attempts to return the signing certificate for the provided certificate. -// No signature validation is performed -func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) { - if c.Details.Issuer == "" { - return nil, fmt.Errorf("no issuer in certificate") - } - - signer, ok := ncp.CAs[c.Details.Issuer] - if ok { - return signer, nil - } - - return nil, fmt.Errorf("could not find ca for the certificate") -} - -// GetFingerprints returns an array of trusted CA fingerprints -func (ncp *NebulaCAPool) GetFingerprints() []string { - fp := make([]string, len(ncp.CAs)) - - i := 0 - for k := range ncp.CAs { - fp[i] = k - i++ - } - - return fp -} diff --git a/cert/ca_pool.go b/cert/ca_pool.go new file mode 100644 index 0000000..2bf480f --- /dev/null +++ b/cert/ca_pool.go @@ -0,0 +1,296 @@ +package cert + +import ( + "errors" + "fmt" + "net/netip" + "slices" + "strings" + "time" +) + +type CAPool struct { + CAs map[string]*CachedCertificate + certBlocklist map[string]struct{} +} + +// NewCAPool creates an empty CAPool +func NewCAPool() *CAPool { + ca := CAPool{ + CAs: make(map[string]*CachedCertificate), + certBlocklist: make(map[string]struct{}), + } + + return &ca +} + +// NewCAPoolFromPEM will create a new CA pool from the provided +// input bytes, which must be a PEM-encoded set of nebula certificates. +// If the pool contains any expired certificates, an ErrExpired will be +// returned along with the pool. The caller must handle any such errors. +func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { + pool := NewCAPool() + var err error + var expired bool + for { + caPEMs, err = pool.AddCAFromPEM(caPEMs) + if errors.Is(err, ErrExpired) { + expired = true + err = nil + } + if err != nil { + return nil, err + } + if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { + break + } + } + + if expired { + return pool, ErrExpired + } + + return pool, nil +} + +// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool. +// Only the first pem encoded object will be consumed, any remaining bytes are returned. +// Parsed certificates will be verified and must be a CA +func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) { + c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes) + if err != nil { + return pemBytes, err + } + + err = ncp.AddCA(c) + if err != nil { + return pemBytes, err + } + + return pemBytes, nil +} + +// AddCA verifies a Nebula CA certificate and adds it to the pool. +func (ncp *CAPool) AddCA(c Certificate) error { + if !c.IsCA() { + return fmt.Errorf("%s: %w", c.Name(), ErrNotCA) + } + + if !c.CheckSignature(c.PublicKey()) { + return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned) + } + + sum, err := c.Fingerprint() + if err != nil { + return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name()) + } + + cc := &CachedCertificate{ + Certificate: c, + Fingerprint: sum, + InvertedGroups: make(map[string]struct{}), + } + + for _, g := range c.Groups() { + cc.InvertedGroups[g] = struct{}{} + } + + ncp.CAs[sum] = cc + + if c.Expired(time.Now()) { + return fmt.Errorf("%s: %w", c.Name(), ErrExpired) + } + + return nil +} + +// BlocklistFingerprint adds a cert fingerprint to the blocklist +func (ncp *CAPool) BlocklistFingerprint(f string) { + ncp.certBlocklist[f] = struct{}{} +} + +// ResetCertBlocklist removes all previously blocklisted cert fingerprints +func (ncp *CAPool) ResetCertBlocklist() { + ncp.certBlocklist = make(map[string]struct{}) +} + +// IsBlocklisted tests the provided fingerprint against the pools blocklist. +// Returns true if the fingerprint is blocked. +func (ncp *CAPool) IsBlocklisted(fingerprint string) bool { + if _, ok := ncp.certBlocklist[fingerprint]; ok { + return true + } + + return false +} + +// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool. +// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts +// to increase performance. +func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) { + if c == nil { + return nil, fmt.Errorf("no certificate") + } + fp, err := c.Fingerprint() + if err != nil { + return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err) + } + + signer, err := ncp.verify(c, now, fp, "") + if err != nil { + return nil, err + } + + cc := CachedCertificate{ + Certificate: c, + InvertedGroups: make(map[string]struct{}), + Fingerprint: fp, + signerFingerprint: signer.Fingerprint, + } + + for _, g := range c.Groups() { + cc.InvertedGroups[g] = struct{}{} + } + + return &cc, nil +} + +// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and +// is a cheaper operation to perform as a result. +func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { + _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint) + return err +} + +func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) { + if ncp.IsBlocklisted(certFp) { + return nil, ErrBlockListed + } + + signer, err := ncp.GetCAForCert(c) + if err != nil { + return nil, err + } + + if signer.Certificate.Expired(now) { + return nil, ErrRootExpired + } + + if c.Expired(now) { + return nil, ErrExpired + } + + // If we are checking a cached certificate then we can bail early here + // Either the root is no longer trusted or everything is fine + if len(signerFp) > 0 { + if signerFp != signer.Fingerprint { + return nil, ErrFingerprintMismatch + } + return signer, nil + } + if !c.CheckSignature(signer.Certificate.PublicKey()) { + return nil, ErrSignatureMismatch + } + + err = CheckCAConstraints(signer.Certificate, c) + if err != nil { + return nil, err + } + + return signer, nil +} + +// GetCAForCert attempts to return the signing certificate for the provided certificate. +// No signature validation is performed +func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) { + issuer := c.Issuer() + if issuer == "" { + return nil, fmt.Errorf("no issuer in certificate") + } + + signer, ok := ncp.CAs[issuer] + if ok { + return signer, nil + } + + return nil, ErrCaNotFound +} + +// GetFingerprints returns an array of trusted CA fingerprints +func (ncp *CAPool) GetFingerprints() []string { + fp := make([]string, len(ncp.CAs)) + + i := 0 + for k := range ncp.CAs { + fp[i] = k + i++ + } + + return fp +} + +// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate. +func CheckCAConstraints(signer Certificate, sub Certificate) error { + return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks()) +} + +// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested. +func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error { + // Make sure this cert isn't valid after the root + if notAfter.After(signer.NotAfter()) { + return fmt.Errorf("certificate expires after signing certificate") + } + + // Make sure this cert wasn't valid before the root + if notBefore.Before(signer.NotBefore()) { + return fmt.Errorf("certificate is valid before the signing certificate") + } + + // If the signer has a limited set of groups make sure the cert only contains a subset + signerGroups := signer.Groups() + if len(signerGroups) > 0 { + for _, g := range groups { + if !slices.Contains(signerGroups, g) { + return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) + } + } + } + + // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset + signingNetworks := signer.Networks() + if len(signingNetworks) > 0 { + for _, certNetwork := range networks { + found := false + for _, signingNetwork := range signingNetworks { + if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String()) + } + } + } + + // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset + signingUnsafeNetworks := signer.UnsafeNetworks() + if len(signingUnsafeNetworks) > 0 { + for _, certUnsafeNetwork := range unsafeNetworks { + found := false + for _, caNetwork := range signingUnsafeNetworks { + if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String()) + } + } + } + + return nil +} diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go new file mode 100644 index 0000000..b0fdd5f --- /dev/null +++ b/cert/ca_pool_test.go @@ -0,0 +1,560 @@ +package cert + +import ( + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCAPoolFromBytes(t *testing.T) { + noNewLines := ` +# Current provisional, Remove once everything moves over to the real root. +-----BEGIN NEBULA CERTIFICATE----- +Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ +PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf +2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== +-----END NEBULA CERTIFICATE----- +# root-ca01 +-----BEGIN NEBULA CERTIFICATE----- +CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br +BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye +rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== +-----END NEBULA CERTIFICATE----- +` + + withNewLines := ` +# Current provisional, Remove once everything moves over to the real root. + +-----BEGIN NEBULA CERTIFICATE----- +Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ +PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf +2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== +-----END NEBULA CERTIFICATE----- + +# root-ca01 + + +-----BEGIN NEBULA CERTIFICATE----- +CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br +BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye +rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== +-----END NEBULA CERTIFICATE----- + +` + + expired := ` +# expired certificate +-----BEGIN NEBULA CERTIFICATE----- +CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA +7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8 +Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0= +-----END NEBULA CERTIFICATE----- +` + + p256 := ` +# p256 certificate +-----BEGIN NEBULA CERTIFICATE----- +CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp +k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe ++0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq +75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA== +-----END NEBULA CERTIFICATE----- +` + + rootCA := certificateV1{ + details: detailsV1{ + name: "nebula root ca", + }, + } + + rootCA01 := certificateV1{ + details: detailsV1{ + name: "nebula root ca 01", + }, + } + + rootCAP256 := certificateV1{ + details: detailsV1{ + name: "nebula P256 test", + }, + } + + p, err := NewCAPoolFromPEM([]byte(noNewLines)) + require.NoError(t, err) + assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) + + pp, err := NewCAPoolFromPEM([]byte(withNewLines)) + require.NoError(t, err) + assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) + + // expired cert, no valid certs + ppp, err := NewCAPoolFromPEM([]byte(expired)) + assert.Equal(t, ErrExpired, err) + assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name()) + + // expired cert, with valid certs + pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) + assert.Equal(t, ErrExpired, err) + assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) + assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name()) + assert.Len(t, pppp.CAs, 3) + + ppppp, err := NewCAPoolFromPEM([]byte(p256)) + require.NoError(t, err) + assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) + assert.Len(t, ppppp.CAs, 1) +} + +func TestCertificateV1_Verify(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + require.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + require.NoError(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + require.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) +} + +func TestCertificateV1_VerifyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + require.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + require.NoError(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + require.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) +} + +func TestCertificateV1_Verify_IPs(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) +} + +func TestCertificateV1_Verify_Subnets(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) +} + +func TestCertificateV2_Verify(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + require.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + require.NoError(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + require.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) +} + +func TestCertificateV2_VerifyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + require.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + require.NoError(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + require.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) +} + +func TestCertificateV2_Verify_IPs(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) +} + +func TestCertificateV2_Verify_Subnets(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + require.NoError(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) +} diff --git a/cert/cert.go b/cert/cert.go index a0164f7..38a2528 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -1,1029 +1,151 @@ package cert import ( - "bytes" - "crypto/ecdh" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "encoding/json" - "encoding/pem" - "errors" "fmt" - "math" - "math/big" - "net" - "sync/atomic" + "net/netip" "time" - - "golang.org/x/crypto/curve25519" - "google.golang.org/protobuf/proto" ) -const publicKeyLen = 32 +type Version uint8 const ( - CertBanner = "NEBULA CERTIFICATE" - X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" - X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" - EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" - Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" - Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" - - P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" - P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" - EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" - ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" + VersionPre1 Version = 0 + Version1 Version = 1 + Version2 Version = 2 ) -type NebulaCertificate struct { - Details NebulaCertificateDetails - Signature []byte +type Certificate interface { + // Version defines the underlying certificate structure and wire protocol version + // Version1 certificates are ipv4 only and uses protobuf serialization + // Version2 certificates are ipv4 or ipv6 and uses asn.1 serialization + Version() Version - // the cached hex string of the calculated sha256sum - // for VerifyWithCache - sha256sum atomic.Pointer[string] + // Name is the human-readable name that identifies this certificate. + Name() string - // the cached public key bytes if they were verified as the signer - // for VerifyWithCache - signatureVerified atomic.Pointer[[]byte] + // Networks is a list of ip addresses and network sizes assigned to this certificate. + // If IsCA is true then certificates signed by this CA can only have ip addresses and + // networks that are contained by an entry in this list. + Networks() []netip.Prefix + + // UnsafeNetworks is a list of networks that this host can act as an unsafe router for. + // If IsCA is true then certificates signed by this CA can only have networks that are + // contained by an entry in this list. + UnsafeNetworks() []netip.Prefix + + // Groups is a list of identities that can be used to write more general firewall rule + // definitions. + // If IsCA is true then certificates signed by this CA can only use groups that are + // in this list. + Groups() []string + + // IsCA signifies if this is a certificate authority (true) or a host certificate (false). + // It is invalid to use a CA certificate as a host certificate. + IsCA() bool + + // NotBefore is the time at which this certificate becomes valid. + // If IsCA is true then certificate signed by this CA can not have a time before this. + NotBefore() time.Time + + // NotAfter is the time at which this certificate becomes invalid. + // If IsCA is true then certificate signed by this CA can not have a time after this. + NotAfter() time.Time + + // Issuer is the fingerprint of the CA that signed this certificate. + // If IsCA is true then this will be empty. + Issuer() string + + // PublicKey is the raw bytes to be used in asymmetric cryptographic operations. + PublicKey() []byte + + // Curve identifies which curve was used for the PublicKey and Signature. + Curve() Curve + + // Signature is the cryptographic seal for all the details of this certificate. + // CheckSignature can be used to verify that the details of this certificate are valid. + Signature() []byte + + // CheckSignature will check that the certificate Signature() matches the + // computed signature. A true result means this certificate has not been tampered with. + CheckSignature(signingPublicKey []byte) bool + + // Fingerprint returns the hex encoded sha256 sum of the certificate. + // This acts as a unique fingerprint and can be used to blocklist certificates. + Fingerprint() (string, error) + + // Expired tests if the certificate is valid for the provided time. + Expired(t time.Time) bool + + // VerifyPrivateKey returns an error if the private key is not a pair with the certificates public key. + VerifyPrivateKey(curve Curve, privateKey []byte) error + + // Marshal will return the byte representation of this certificate + // This is primarily the format transmitted on the wire. + Marshal() ([]byte, error) + + // MarshalForHandshakes prepares the bytes needed to use directly in a handshake + MarshalForHandshakes() ([]byte, error) + + // MarshalPEM will return a PEM encoded representation of this certificate + // This is primarily the format stored on disk + MarshalPEM() ([]byte, error) + + // MarshalJSON will return the json representation of this certificate + MarshalJSON() ([]byte, error) + + // String will return a human-readable representation of this certificate + String() string + + // Copy creates a copy of the certificate + Copy() Certificate } -type NebulaCertificateDetails struct { - Name string - Ips []*net.IPNet - Subnets []*net.IPNet - Groups []string - NotBefore time.Time - NotAfter time.Time - PublicKey []byte - IsCA bool - Issuer string - - // Map of groups for faster lookup - InvertedGroups map[string]struct{} - - Curve Curve +// CachedCertificate represents a verified certificate with some cached fields to improve +// performance. +type CachedCertificate struct { + Certificate Certificate + InvertedGroups map[string]struct{} + Fingerprint string + signerFingerprint string } -type NebulaEncryptedData struct { - EncryptionMetadata NebulaEncryptionMetadata - Ciphertext []byte +func (cc *CachedCertificate) String() string { + return cc.Certificate.String() } -type NebulaEncryptionMetadata struct { - EncryptionAlgorithm string - Argon2Parameters Argon2Parameters -} - -type m map[string]interface{} - -// Returned if we try to unmarshal an encrypted private key without a passphrase -var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") - -// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert -func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { - if len(b) == 0 { - return nil, fmt.Errorf("nil byte array") +// Recombine will attempt to unmarshal a certificate received in a handshake. +// Handshakes save space by placing the peers public key in a different part of the packet, we have to +// reassemble the actual certificate structure with that in mind. +func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) { + if publicKey == nil { + return nil, ErrNoPeerStaticKey } - var rc RawNebulaCertificate - err := proto.Unmarshal(b, &rc) + + if rawCertBytes == nil { + return nil, ErrNoPayload + } + + var c Certificate + var err error + + switch v { + // Implementations must ensure the result is a valid cert! + case VersionPre1, Version1: + c, err = unmarshalCertificateV1(rawCertBytes, publicKey) + case Version2: + c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve) + default: + //TODO: CERT-V2 make a static var + return nil, fmt.Errorf("unknown certificate version %d", v) + } + if err != nil { return nil, err } - if rc.Details == nil { - return nil, fmt.Errorf("encoded Details was nil") + if c.Curve() != curve { + return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String()) } - if len(rc.Details.Ips)%2 != 0 { - return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") - } - - if len(rc.Details.Subnets)%2 != 0 { - return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") - } - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: rc.Details.Name, - Groups: make([]string, len(rc.Details.Groups)), - Ips: make([]*net.IPNet, len(rc.Details.Ips)/2), - Subnets: make([]*net.IPNet, len(rc.Details.Subnets)/2), - NotBefore: time.Unix(rc.Details.NotBefore, 0), - NotAfter: time.Unix(rc.Details.NotAfter, 0), - PublicKey: make([]byte, len(rc.Details.PublicKey)), - IsCA: rc.Details.IsCA, - InvertedGroups: make(map[string]struct{}), - Curve: rc.Details.Curve, - }, - Signature: make([]byte, len(rc.Signature)), - } - - copy(nc.Signature, rc.Signature) - copy(nc.Details.Groups, rc.Details.Groups) - nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer) - - if len(rc.Details.PublicKey) < publicKeyLen { - return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) - } - copy(nc.Details.PublicKey, rc.Details.PublicKey) - - for i, rawIp := range rc.Details.Ips { - if i%2 == 0 { - nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)} - } else { - nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp)) - } - } - - for i, rawIp := range rc.Details.Subnets { - if i%2 == 0 { - nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)} - } else { - nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp)) - } - } - - for _, g := range rc.Details.Groups { - nc.Details.InvertedGroups[g] = struct{}{} - } - - return &nc, nil -} - -// UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data -// or an error on failure -func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) { - p, r := pem.Decode(b) - if p == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if p.Type != CertBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner") - } - nc, err := UnmarshalNebulaCertificate(p.Bytes) - return nc, r, err -} - -func MarshalPrivateKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b}) - default: - return nil - } -} - -func MarshalSigningPrivateKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b}) - default: - return nil - } -} - -// MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key -func MarshalX25519PrivateKey(b []byte) []byte { - return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) -} - -// MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key -func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte { - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key}) -} - -func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var expectedLen int - var curve Curve - switch k.Type { - case X25519PrivateKeyBanner: - expectedLen = 32 - curve = Curve_CURVE25519 - case P256PrivateKeyBanner: - expectedLen = 32 - curve = Curve_P256 - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner") - } - if len(k.Bytes) != expectedLen { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve) - } - return k.Bytes, r, curve, nil -} - -func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var curve Curve - switch k.Type { - case EncryptedEd25519PrivateKeyBanner: - return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted - case EncryptedECDSAP256PrivateKeyBanner: - return nil, nil, Curve_P256, ErrPrivateKeyEncrypted - case Ed25519PrivateKeyBanner: - curve = Curve_CURVE25519 - if len(k.Bytes) != ed25519.PrivateKeySize { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize) - } - case ECDSAP256PrivateKeyBanner: - curve = Curve_P256 - if len(k.Bytes) != 32 { - return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") - } - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner") - } - return k.Bytes, r, curve, nil -} - -// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key -func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { - ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) - if err != nil { - return nil, err - } - - b, err = proto.Marshal(&RawNebulaEncryptedData{ - EncryptionMetadata: &RawNebulaEncryptionMetadata{ - EncryptionAlgorithm: "AES-256-GCM", - Argon2Parameters: &RawNebulaArgon2Parameters{ - Version: kdfParams.version, - Memory: kdfParams.Memory, - Parallelism: uint32(kdfParams.Parallelism), - Iterations: kdfParams.Iterations, - Salt: kdfParams.salt, - }, - }, - Ciphertext: ciphertext, - }) - if err != nil { - return nil, err - } - - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil - default: - return nil, fmt.Errorf("invalid curve: %v", curve) - } -} - -// UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b -// or an error on failure -func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != X25519PrivateKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner") - } - if len(k.Bytes) != publicKeyLen { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key") - } - - return k.Bytes, r, nil -} - -// UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b -// or an error on failure -func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - - if k.Type == EncryptedEd25519PrivateKeyBanner { - return nil, r, ErrPrivateKeyEncrypted - } else if k.Type != Ed25519PrivateKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner") - } - - if len(k.Bytes) != ed25519.PrivateKeySize { - return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") - } - - return k.Bytes, r, nil -} - -// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its -// protobuf-generated struct. -func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { - if len(b) == 0 { - return nil, fmt.Errorf("nil byte array") - } - var rned RawNebulaEncryptedData - err := proto.Unmarshal(b, &rned) - if err != nil { - return nil, err - } - - if rned.EncryptionMetadata == nil { - return nil, fmt.Errorf("encoded EncryptionMetadata was nil") - } - - if rned.EncryptionMetadata.Argon2Parameters == nil { - return nil, fmt.Errorf("encoded Argon2Parameters was nil") - } - - params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) - if err != nil { - return nil, err - } - - ned := NebulaEncryptedData{ - EncryptionMetadata: NebulaEncryptionMetadata{ - EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, - Argon2Parameters: *params, - }, - Ciphertext: rned.Ciphertext, - } - - return &ned, nil -} - -func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { - if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { - return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) - } - if params.Memory <= 0 || params.Memory > math.MaxUint32 { - return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) - } - if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { - return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) - } - if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { - return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) - } - - return &Argon2Parameters{ - version: rune(params.Version), - Memory: uint32(params.Memory), - Parallelism: uint8(params.Parallelism), - Iterations: uint32(params.Iterations), - salt: params.Salt, - }, nil - -} - -// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with -// the given passphrase, returning any other bytes b or an error on failure -func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) { - var curve Curve - - k, r := pem.Decode(b) - if k == nil { - return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - - switch k.Type { - case EncryptedEd25519PrivateKeyBanner: - curve = Curve_CURVE25519 - case EncryptedECDSAP256PrivateKeyBanner: - curve = Curve_P256 - default: - return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") - } - - ned, err := UnmarshalNebulaEncryptedData(k.Bytes) - if err != nil { - return curve, nil, r, err - } - - var bytes []byte - switch ned.EncryptionMetadata.EncryptionAlgorithm { - case "AES-256-GCM": - bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) - if err != nil { - return curve, nil, r, err - } - default: - return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) - } - - switch curve { - case Curve_CURVE25519: - if len(bytes) != ed25519.PrivateKeySize { - return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize) - } - case Curve_P256: - if len(bytes) != 32 { - return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") - } - } - - return curve, bytes, r, nil -} - -func MarshalPublicKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) - default: - return nil - } -} - -// MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key -func MarshalX25519PublicKey(b []byte) []byte { - return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) -} - -// MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key -func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte { - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key}) -} - -func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var expectedLen int - var curve Curve - switch k.Type { - case X25519PublicKeyBanner: - expectedLen = 32 - curve = Curve_CURVE25519 - case P256PublicKeyBanner: - // Uncompressed - expectedLen = 65 - curve = Curve_P256 - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner") - } - if len(k.Bytes) != expectedLen { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve) - } - return k.Bytes, r, curve, nil -} - -// UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b -// or an error on failure -func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != X25519PublicKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner") - } - if len(k.Bytes) != publicKeyLen { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key") - } - - return k.Bytes, r, nil -} - -// UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b -// or an error on failure -func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != Ed25519PublicKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner") - } - if len(k.Bytes) != ed25519.PublicKeySize { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key") - } - - return k.Bytes, r, nil -} - -// Sign signs a nebula cert with the provided private key -func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error { - if curve != nc.Details.Curve { - return fmt.Errorf("curve in cert and private key supplied don't match") - } - - b, err := proto.Marshal(nc.getRawDetails()) - if err != nil { - return err - } - - var sig []byte - - switch curve { - case Curve_CURVE25519: - signer := ed25519.PrivateKey(key) - sig = ed25519.Sign(signer, b) - case Curve_P256: - signer := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), - } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) - - // We need to hash first for ECDSA - // - https://pkg.go.dev/crypto/ecdsa#SignASN1 - hashed := sha256.Sum256(b) - sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) - if err != nil { - return err - } - default: - return fmt.Errorf("invalid curve: %s", nc.Details.Curve) - } - - nc.Signature = sig - return nil -} - -// CheckSignature verifies the signature against the provided public key -func (nc *NebulaCertificate) CheckSignature(key []byte) bool { - b, err := proto.Marshal(nc.getRawDetails()) - if err != nil { - return false - } - switch nc.Details.Curve { - case Curve_CURVE25519: - return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature) - case Curve_P256: - x, y := elliptic.Unmarshal(elliptic.P256(), key) - pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} - hashed := sha256.Sum256(b) - return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature) - default: - return false - } -} - -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool { - if !useCache { - return nc.CheckSignature(key) - } - - if v := nc.signatureVerified.Load(); v != nil { - return bytes.Equal(*v, key) - } - - verified := nc.CheckSignature(key) - if verified { - keyCopy := make([]byte, len(key)) - copy(keyCopy, key) - nc.signatureVerified.Store(&keyCopy) - } - - return verified -} - -// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false -func (nc *NebulaCertificate) Expired(t time.Time) bool { - return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t) -} - -// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) { - return nc.verify(t, ncp, false) -} - -// VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -// -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *NebulaCAPool) (bool, error) { - return nc.verify(t, ncp, true) -} - -// ResetCache resets the cache used by VerifyWithCache. -func (nc *NebulaCertificate) ResetCache() { - nc.sha256sum.Store(nil) - nc.signatureVerified.Store(nil) -} - -// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -func (nc *NebulaCertificate) verify(t time.Time, ncp *NebulaCAPool, useCache bool) (bool, error) { - if ncp.isBlocklistedWithCache(nc, useCache) { - return false, ErrBlockListed - } - - signer, err := ncp.GetCAForCert(nc) - if err != nil { - return false, err - } - - if signer.Expired(t) { - return false, ErrRootExpired - } - - if nc.Expired(t) { - return false, ErrExpired - } - - if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) { - return false, ErrSignatureMismatch - } - - if err := nc.CheckRootConstrains(signer); err != nil { - return false, err - } - - return true, nil -} - -// CheckRootConstrains returns an error if the certificate violates constraints set on the root (groups, ips, subnets) -func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) error { - // Make sure this cert wasn't valid before the root - if signer.Details.NotAfter.Before(nc.Details.NotAfter) { - return fmt.Errorf("certificate expires after signing certificate") - } - - // Make sure this cert isn't valid after the root - if signer.Details.NotBefore.After(nc.Details.NotBefore) { - return fmt.Errorf("certificate is valid before the signing certificate") - } - - // If the signer has a limited set of groups make sure the cert only contains a subset - if len(signer.Details.InvertedGroups) > 0 { - for _, g := range nc.Details.Groups { - if _, ok := signer.Details.InvertedGroups[g]; !ok { - return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) - } - } - } - - // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset - if len(signer.Details.Ips) > 0 { - for _, ip := range nc.Details.Ips { - if !netMatch(ip, signer.Details.Ips) { - return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", ip.String()) - } - } - } - - // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset - if len(signer.Details.Subnets) > 0 { - for _, subnet := range nc.Details.Subnets { - if !netMatch(subnet, signer.Details.Subnets) { - return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subnet) - } - } - } - - return nil -} - -// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match -func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { - if curve != nc.Details.Curve { - return fmt.Errorf("curve in cert and private key supplied don't match") - } - if nc.Details.IsCA { - switch curve { - case Curve_CURVE25519: - // the call to PublicKey below will panic slice bounds out of range otherwise - if len(key) != ed25519.PrivateKeySize { - return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") - } - - if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return fmt.Errorf("cannot parse private key as P256") - } - pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, nc.Details.PublicKey) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - default: - return fmt.Errorf("invalid curve: %s", curve) - } - return nil - } - - var pub []byte - switch curve { - case Curve_CURVE25519: - var err error - pub, err = curve25519.X25519(key, curve25519.Basepoint) - if err != nil { - return err - } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return err - } - pub = privkey.PublicKey().Bytes() - default: - return fmt.Errorf("invalid curve: %s", curve) - } - if !bytes.Equal(pub, nc.Details.PublicKey) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - - return nil -} - -// String will return a pretty printed representation of a nebula cert -func (nc *NebulaCertificate) String() string { - if nc == nil { - return "NebulaCertificate {}\n" - } - - s := "NebulaCertificate {\n" - s += "\tDetails {\n" - s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name) - - if len(nc.Details.Ips) > 0 { - s += "\t\tIps: [\n" - for _, ip := range nc.Details.Ips { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tIps: []\n" - } - - if len(nc.Details.Subnets) > 0 { - s += "\t\tSubnets: [\n" - for _, ip := range nc.Details.Subnets { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tSubnets: []\n" - } - - if len(nc.Details.Groups) > 0 { - s += "\t\tGroups: [\n" - for _, g := range nc.Details.Groups { - s += fmt.Sprintf("\t\t\t\"%v\"\n", g) - } - s += "\t\t]\n" - } else { - s += "\t\tGroups: []\n" - } - - s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore) - s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter) - s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA) - s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer) - s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey) - s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve) - s += "\t}\n" - fp, err := nc.Sha256Sum() - if err == nil { - s += fmt.Sprintf("\tFingerprint: %s\n", fp) - } - s += fmt.Sprintf("\tSignature: %x\n", nc.Signature) - s += "}" - - return s -} - -// getRawDetails marshals the raw details into protobuf ready struct -func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails { - rd := &RawNebulaCertificateDetails{ - Name: nc.Details.Name, - Groups: nc.Details.Groups, - NotBefore: nc.Details.NotBefore.Unix(), - NotAfter: nc.Details.NotAfter.Unix(), - PublicKey: make([]byte, len(nc.Details.PublicKey)), - IsCA: nc.Details.IsCA, - Curve: nc.Details.Curve, - } - - for _, ipNet := range nc.Details.Ips { - rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask)) - } - - for _, ipNet := range nc.Details.Subnets { - rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask)) - } - - copy(rd.PublicKey, nc.Details.PublicKey[:]) - - // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer) - - return rd -} - -// Marshal will marshal a nebula cert into a protobuf byte array -func (nc *NebulaCertificate) Marshal() ([]byte, error) { - rc := RawNebulaCertificate{ - Details: nc.getRawDetails(), - Signature: nc.Signature, - } - - return proto.Marshal(&rc) -} - -// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result -func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) { - b, err := nc.Marshal() - if err != nil { - return nil, err - } - return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil -} - -// Sha256Sum calculates a sha-256 sum of the marshaled certificate -func (nc *NebulaCertificate) Sha256Sum() (string, error) { - b, err := nc.Marshal() - if err != nil { - return "", err - } - - sum := sha256.Sum256(b) - return hex.EncodeToString(sum[:]), nil -} - -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) { - if !useCache { - return nc.Sha256Sum() - } - - if s := nc.sha256sum.Load(); s != nil { - return *s, nil - } - s, err := nc.Sha256Sum() - if err != nil { - return s, err - } - - nc.sha256sum.Store(&s) - return s, nil -} - -func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) { - toString := func(ips []*net.IPNet) []string { - s := []string{} - for _, ip := range ips { - s = append(s, ip.String()) - } - return s - } - - fp, _ := nc.Sha256Sum() - jc := m{ - "details": m{ - "name": nc.Details.Name, - "ips": toString(nc.Details.Ips), - "subnets": toString(nc.Details.Subnets), - "groups": nc.Details.Groups, - "notBefore": nc.Details.NotBefore, - "notAfter": nc.Details.NotAfter, - "publicKey": fmt.Sprintf("%x", nc.Details.PublicKey), - "isCa": nc.Details.IsCA, - "issuer": nc.Details.Issuer, - "curve": nc.Details.Curve.String(), - }, - "fingerprint": fp, - "signature": fmt.Sprintf("%x", nc.Signature), - } - return json.Marshal(jc) -} - -//func (nc *NebulaCertificate) Copy() *NebulaCertificate { -// r, err := nc.Marshal() -// if err != nil { -// //TODO -// return nil -// } -// -// c, err := UnmarshalNebulaCertificate(r) -// return c -//} - -func (nc *NebulaCertificate) Copy() *NebulaCertificate { - c := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: nc.Details.Name, - Groups: make([]string, len(nc.Details.Groups)), - Ips: make([]*net.IPNet, len(nc.Details.Ips)), - Subnets: make([]*net.IPNet, len(nc.Details.Subnets)), - NotBefore: nc.Details.NotBefore, - NotAfter: nc.Details.NotAfter, - PublicKey: make([]byte, len(nc.Details.PublicKey)), - IsCA: nc.Details.IsCA, - Issuer: nc.Details.Issuer, - InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)), - }, - Signature: make([]byte, len(nc.Signature)), - } - - copy(c.Signature, nc.Signature) - copy(c.Details.Groups, nc.Details.Groups) - copy(c.Details.PublicKey, nc.Details.PublicKey) - - for i, p := range nc.Details.Ips { - c.Details.Ips[i] = &net.IPNet{ - IP: make(net.IP, len(p.IP)), - Mask: make(net.IPMask, len(p.Mask)), - } - copy(c.Details.Ips[i].IP, p.IP) - copy(c.Details.Ips[i].Mask, p.Mask) - } - - for i, p := range nc.Details.Subnets { - c.Details.Subnets[i] = &net.IPNet{ - IP: make(net.IP, len(p.IP)), - Mask: make(net.IPMask, len(p.Mask)), - } - copy(c.Details.Subnets[i].IP, p.IP) - copy(c.Details.Subnets[i].Mask, p.Mask) - } - - for g := range nc.Details.InvertedGroups { - c.Details.InvertedGroups[g] = struct{}{} - } - - return c -} - -func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool { - for _, net := range rootIps { - if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) { - return true - } - } - - return false -} - -func maskContains(caMask, certMask net.IPMask) bool { - caM := maskTo4(caMask) - cM := maskTo4(certMask) - // Make sure forcing to ipv4 didn't nuke us - if caM == nil || cM == nil { - return false - } - - // Make sure the cert mask is not greater than the ca mask - for i := 0; i < len(caMask); i++ { - if caM[i] > cM[i] { - return false - } - } - - return true -} - -func maskTo4(ip net.IPMask) net.IPMask { - if len(ip) == net.IPv4len { - return ip - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16] - } - - return nil -} - -func isZeros(b []byte) bool { - for i := 0; i < len(b); i++ { - if b[i] != 0 { - return false - } - } - return true -} - -func ip2int(ip []byte) uint32 { - if len(ip) == 16 { - return binary.BigEndian.Uint32(ip[12:16]) - } - return binary.BigEndian.Uint32(ip) -} - -func int2ip(nn uint32) net.IP { - ip := make(net.IP, net.IPv4len) - binary.BigEndian.PutUint32(ip, nn) - return ip + return c, nil } diff --git a/cert/cert_test.go b/cert/cert_test.go deleted file mode 100644 index 30e99ec..0000000 --- a/cert/cert_test.go +++ /dev/null @@ -1,1230 +0,0 @@ -package cert - -import ( - "crypto/ecdh" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "fmt" - "io" - "net" - "testing" - "time" - - "github.com/slackhq/nebula/test" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" - "google.golang.org/protobuf/proto" -) - -func TestMarshalingNebulaCertificate(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - Signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) - - nc2, err := UnmarshalNebulaCertificate(b) - assert.Nil(t, err) - - assert.Equal(t, nc.Signature, nc2.Signature) - assert.Equal(t, nc.Details.Name, nc2.Details.Name) - assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore) - assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter) - assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey) - assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA) - - // IP byte arrays can be 4 or 16 in length so we have to go this route - assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips)) - for i, wIp := range nc.Details.Ips { - assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String()) - } - - assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets)) - for i, wIp := range nc.Details.Subnets { - assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String()) - } - - assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups) -} - -func TestNebulaCertificate_Sign(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - } - - pub, priv, err := ed25519.GenerateKey(rand.Reader) - assert.Nil(t, err) - assert.False(t, nc.CheckSignature(pub)) - assert.Nil(t, nc.Sign(Curve_CURVE25519, priv)) - assert.True(t, nc.CheckSignature(pub)) - - _, err = nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) -} - -func TestNebulaCertificate_SignP256(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Curve: Curve_P256, - Issuer: "1234567890abcedfghij1234567890ab", - }, - } - - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) - rawPriv := priv.D.FillBytes(make([]byte, 32)) - - assert.Nil(t, err) - assert.False(t, nc.CheckSignature(pub)) - assert.Nil(t, nc.Sign(Curve_P256, rawPriv)) - assert.True(t, nc.CheckSignature(pub)) - - _, err = nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) -} - -func TestNebulaCertificate_Expired(t *testing.T) { - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), - NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), - }, - } - - assert.True(t, nc.Expired(time.Now().Add(time.Hour))) - assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) - assert.False(t, nc.Expired(time.Now())) -} - -func TestNebulaCertificate_MarshalJSON(t *testing.T) { - time.Local = time.UTC - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), - NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - Signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.MarshalJSON() - assert.Nil(t, err) - assert.Equal( - t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", - string(b), - ) -} - -func TestNebulaCertificate_Verify(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - h, err := ca.Sha256Sum() - assert.Nil(t, err) - - caPool := NewCAPool() - caPool.CAs[h] = ca - - f, err := c.Sha256Sum() - assert.Nil(t, err) - caPool.BlocklistFingerprint(f) - - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is in the block list") - - caPool.ResetCertBlocklist() - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) - assert.False(t, v) - assert.EqualError(t, err, "root certificate is expired") - - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is expired") - - // Test group assertion - ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalToPEM() - assert.Nil(t, err) - - caPool = NewCAPool() - caPool.AddCACertificate(caPem) - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) -} - -func TestNebulaCertificate_VerifyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - h, err := ca.Sha256Sum() - assert.Nil(t, err) - - caPool := NewCAPool() - caPool.CAs[h] = ca - - f, err := c.Sha256Sum() - assert.Nil(t, err) - caPool.BlocklistFingerprint(f) - - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is in the block list") - - caPool.ResetCertBlocklist() - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) - assert.False(t, v) - assert.EqualError(t, err, "root certificate is expired") - - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is expired") - - // Test group assertion - ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalToPEM() - assert.Nil(t, err) - - caPool = NewCAPool() - caPool.AddCACertificate(caPem) - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) -} - -func TestNebulaCertificate_Verify_IPs(t *testing.T) { - _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") - _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalToPEM() - assert.Nil(t, err) - - caPool := NewCAPool() - caPool.AddCACertificate(caPem) - - // ip is outside the network - cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}} - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is outside the network reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is within the network but mask is outside - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip is within the network but mask is outside reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip and mask are within the network - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp2, caIp1}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) -} - -func TestNebulaCertificate_Verify_Subnets(t *testing.T) { - _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") - _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalToPEM() - assert.Nil(t, err) - - caPool := NewCAPool() - caPool.AddCACertificate(caPem) - - // ip is outside the network - cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}} - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is outside the network reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is within the network but mask is outside - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip is within the network but mask is outside reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip and mask are within the network - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp2, caIp1}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) -} - -func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) - - _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.NotNil(t, err) - - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - err = c.VerifyPrivateKey(Curve_CURVE25519, priv) - assert.Nil(t, err) - - _, priv2 := x25519Keypair() - err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.NotNil(t, err) -} - -func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_P256, caKey) - assert.Nil(t, err) - - _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.NotNil(t, err) - - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - err = c.VerifyPrivateKey(Curve_P256, priv) - assert.Nil(t, err) - - _, priv2 := p256Keypair() - err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.NotNil(t, err) -} - -func TestNewCAPoolFromBytes(t *testing.T) { - noNewLines := ` -# Current provisional, Remove once everything moves over to the real root. ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- -# root-ca01 ------BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF ------END NEBULA CERTIFICATE----- -` - - withNewLines := ` -# Current provisional, Remove once everything moves over to the real root. - ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- - -# root-ca01 - - ------BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF ------END NEBULA CERTIFICATE----- - -` - - expired := ` -# expired certificate ------BEGIN NEBULA CERTIFICATE----- -CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4 -vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie -WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= ------END NEBULA CERTIFICATE----- -` - - p256 := ` -# p256 certificate ------BEGIN NEBULA CERTIFICATE----- -CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2 -6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H -76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC -IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX ------END NEBULA CERTIFICATE----- -` - - rootCA := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula root ca", - }, - } - - rootCA01 := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula root ca 01", - }, - } - - rootCAP256 := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula P256 test", - }, - } - - p, err := NewCAPoolFromBytes([]byte(noNewLines)) - assert.Nil(t, err) - assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - - pp, err := NewCAPoolFromBytes([]byte(withNewLines)) - assert.Nil(t, err) - assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - - // expired cert, no valid certs - ppp, err := NewCAPoolFromBytes([]byte(expired)) - assert.Equal(t, ErrExpired, err) - assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") - - // expired cert, with valid certs - pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...)) - assert.Equal(t, ErrExpired, err) - assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") - assert.Equal(t, len(pppp.CAs), 3) - - ppppp, err := NewCAPoolFromBytes([]byte(p256)) - assert.Nil(t, err) - assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name) - assert.Equal(t, len(ppppp.CAs), 1) -} - -func appendByteSlices(b ...[]byte) []byte { - retSlice := []byte{} - for _, v := range b { - retSlice = append(retSlice, v...) - } - return retSlice -} - -func TestUnmrshalCertPEM(t *testing.T) { - goodCert := []byte(` -# A good cert ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- -`) - badBanner := []byte(`# A bad banner ------BEGIN NOT A NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NOT A NEBULA CERTIFICATE----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB --END NEBULA CERTIFICATE----`) - - certBundle := appendByteSlices(goodCert, badBanner, invalidPem) - - // Success test case - cert, rest, err := UnmarshalNebulaCertificateFromPEM(certBundle) - assert.NotNil(t, cert) - assert.Equal(t, rest, append(badBanner, invalidPem...)) - assert.Nil(t, err) - - // Fail due to invalid banner. - cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest) - assert.Nil(t, cert) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula certificate banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest) - assert.Nil(t, cert) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalSigningPrivateKey(t *testing.T) { - privKey := []byte(`# A good key ------BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA ED25519 PRIVATE KEY----- -`) - privP256Key := []byte(`# A good key ------BEGIN NEBULA ECDSA P256 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA ECDSA P256 PRIVATE KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA ------END NEBULA ED25519 PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NOT A NEBULA PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== --END NEBULA ED25519 PRIVATE KEY-----`) - - keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalSigningPrivateKey(keyBundle) - assert.Len(t, k, 64) - assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) - - // Success test case - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) - - // Fail due to short key - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519/ECDSA private key banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) { - passphrase := []byte("DO NOT USE THIS KEY") - privKey := []byte(`# A good key ------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT -oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl -+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB -qrlJ69wer3ZUHFXA ------END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - shortKey := []byte(`# A key which, once decrypted, is too short ------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 -k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe -GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs -rQr3bdH3Oy/WiYU= ------END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner (not encrypted) ------BEGIN NEBULA ED25519 PRIVATE KEY----- -bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG -XgLvodMXZJuaFPssp+WwtA== ------END NEBULA ED25519 PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT -oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl -+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB -qrlJ69wer3ZUHFXA --END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - - keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) - - // Success test case - curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) - assert.Nil(t, err) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Len(t, k, 64) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - - // Fail due to short key - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - - // Fail due to invalid banner - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - - // Fail due to invalid passphrase - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) - assert.EqualError(t, err, "invalid passphrase or corrupt private key") - assert.Nil(t, k) - assert.Equal(t, rest, []byte{}) -} - -func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { - // Having proved that decryption works correctly above, we can test the - // encryption function produces a value which can be decrypted - passphrase := []byte("passphrase") - bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") - kdfParams := NewArgon2Parameters(64*1024, 4, 3) - key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) - assert.Nil(t, err) - - // Verify the "key" can be decrypted successfully - curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) - assert.Len(t, k, 64) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Equal(t, rest, []byte{}) - assert.Nil(t, err) - - // EncryptAndMarshalEd25519PrivateKey does not create any errors itself -} - -func TestUnmarshalPrivateKey(t *testing.T) { - privKey := []byte(`# A good key ------BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA X25519 PRIVATE KEY----- -`) - privP256Key := []byte(`# A good key ------BEGIN NEBULA P256 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA P256 PRIVATE KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA X25519 PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA X25519 PRIVATE KEY-----`) - - keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalPrivateKey(keyBundle) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) - - // Success test case - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) - - // Fail due to short key - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula private key banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalEd25519PublicKey(t *testing.T) { - pubKey := []byte(`# A good key ------BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA ED25519 PUBLIC KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA ED25519 PUBLIC KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PUBLIC KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA ED25519 PUBLIC KEY-----`) - - keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, err := UnmarshalEd25519PublicKey(keyBundle) - assert.Equal(t, len(k), 32) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - - // Fail due to short key - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid ed25519 public key") - - // Fail due to invalid banner - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 public key banner") - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalX25519PublicKey(t *testing.T) { - pubKey := []byte(`# A good key ------BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA X25519 PUBLIC KEY----- -`) - pubP256Key := []byte(`# A good key ------BEGIN NEBULA P256 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA -AAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA P256 PUBLIC KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA X25519 PUBLIC KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PUBLIC KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA X25519 PUBLIC KEY-----`) - - keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalPublicKey(keyBundle) - assert.Equal(t, len(k), 32) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - - // Success test case - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Equal(t, len(k), 65) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - - // Fail due to short key - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper nebula public key banner") - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -// Ensure that upgrading the protobuf library does not change how certificates -// are marshalled, since this would break signature verification -func TestMarshalingNebulaCertificateConsistency(t *testing.T) { - before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) - after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - Signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) - assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) - - b, err = proto.Marshal(nc.getRawDetails()) - assert.Nil(t, err) - //t.Log("Raw cert size:", len(b)) - assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) -} - -func TestNebulaCertificate_Copy(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - cc := c.Copy() - - test.AssertDeepCopyEqual(t, c, cc) -} - -func TestUnmarshalNebulaCertificate(t *testing.T) { - // Test that we don't panic with an invalid certificate (#332) - data := []byte("\x98\x00\x00") - _, err := UnmarshalNebulaCertificate(data) - assert.EqualError(t, err, "encoded Details was nil") -} - -func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = ips - } - - if len(subnets) > 0 { - nc.Details.Subnets = subnets - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(Curve_CURVE25519, priv) - if err != nil { - return nil, nil, nil, err - } - return nc, pub, priv, nil -} - -func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) - rawPriv := priv.D.FillBytes(make([]byte, 32)) - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - Curve: Curve_P256, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = ips - } - - if len(subnets) > 0 { - nc.Details.Subnets = subnets - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(Curve_P256, rawPriv) - if err != nil { - return nil, nil, nil, err - } - return nc, pub, rawPriv, nil -} - -func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { - issuer, err := ca.Sha256Sum() - if err != nil { - return nil, nil, nil, err - } - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - if len(groups) == 0 { - groups = []string{"test-group1", "test-group2", "test-group3"} - } - - if len(ips) == 0 { - ips = []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, - {IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, - {IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, - } - } - - if len(subnets) == 0 { - subnets = []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, - {IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, - {IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, - } - } - - var pub, rawPriv []byte - - switch ca.Details.Curve { - case Curve_CURVE25519: - pub, rawPriv = x25519Keypair() - case Curve_P256: - pub, rawPriv = p256Keypair() - default: - return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve) - } - - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: ips, - Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Curve: ca.Details.Curve, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, - } - - err = nc.Sign(ca.Details.Curve, key) - if err != nil { - return nil, nil, nil, err - } - - return nc, pub, rawPriv, nil -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey -} - -func p256Keypair() ([]byte, []byte) { - privkey, err := ecdh.P256().GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - pubkey := privkey.PublicKey() - return pubkey.Bytes(), privkey.Bytes() -} diff --git a/cert/cert_v1.go b/cert/cert_v1.go new file mode 100644 index 0000000..71d36eb --- /dev/null +++ b/cert/cert_v1.go @@ -0,0 +1,489 @@ +package cert + +import ( + "bytes" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "net" + "net/netip" + "time" + + "golang.org/x/crypto/curve25519" + "google.golang.org/protobuf/proto" +) + +const publicKeyLen = 32 + +type certificateV1 struct { + details detailsV1 + signature []byte +} + +type detailsV1 struct { + name string + networks []netip.Prefix + unsafeNetworks []netip.Prefix + groups []string + notBefore time.Time + notAfter time.Time + publicKey []byte + isCA bool + issuer string + + curve Curve +} + +type m = map[string]any + +func (c *certificateV1) Version() Version { + return Version1 +} + +func (c *certificateV1) Curve() Curve { + return c.details.curve +} + +func (c *certificateV1) Groups() []string { + return c.details.groups +} + +func (c *certificateV1) IsCA() bool { + return c.details.isCA +} + +func (c *certificateV1) Issuer() string { + return c.details.issuer +} + +func (c *certificateV1) Name() string { + return c.details.name +} + +func (c *certificateV1) Networks() []netip.Prefix { + return c.details.networks +} + +func (c *certificateV1) NotAfter() time.Time { + return c.details.notAfter +} + +func (c *certificateV1) NotBefore() time.Time { + return c.details.notBefore +} + +func (c *certificateV1) PublicKey() []byte { + return c.details.publicKey +} + +func (c *certificateV1) Signature() []byte { + return c.signature +} + +func (c *certificateV1) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks +} + +func (c *certificateV1) Fingerprint() (string, error) { + b, err := c.Marshal() + if err != nil { + return "", err + } + + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (c *certificateV1) CheckSignature(key []byte) bool { + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return false + } + switch c.details.curve { + case Curve_CURVE25519: + return ed25519.Verify(key, b, c.signature) + case Curve_P256: + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) + default: + return false + } +} + +func (c *certificateV1) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) +} + +func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.details.curve { + return fmt.Errorf("curve in cert and private key supplied don't match") + } + if c.details.isCA { + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") + } + + if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return fmt.Errorf("cannot parse private key as P256: %w", err) + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, c.details.publicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + default: + return fmt.Errorf("invalid curve: %s", curve) + } + return nil + } + + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return err + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return err + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) + } + if !bytes.Equal(pub, c.details.publicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + + return nil +} + +// getRawDetails marshals the raw details into protobuf ready struct +func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails { + rd := &RawNebulaCertificateDetails{ + Name: c.details.name, + Groups: c.details.groups, + NotBefore: c.details.notBefore.Unix(), + NotAfter: c.details.notAfter.Unix(), + PublicKey: make([]byte, len(c.details.publicKey)), + IsCA: c.details.isCA, + Curve: c.details.curve, + } + + for _, ipNet := range c.details.networks { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) + } + + for _, ipNet := range c.details.unsafeNetworks { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) + } + + copy(rd.PublicKey, c.details.publicKey[:]) + + // I know, this is terrible + rd.Issuer, _ = hex.DecodeString(c.details.issuer) + + return rd +} + +func (c *certificateV1) String() string { + b, err := json.MarshalIndent(c.marshalJSON(), "", "\t") + if err != nil { + return fmt.Sprintf("", err) + } + return string(b) +} + +func (c *certificateV1) MarshalForHandshakes() ([]byte, error) { + pubKey := c.details.publicKey + c.details.publicKey = nil + rawCertNoKey, err := c.Marshal() + if err != nil { + return nil, err + } + c.details.publicKey = pubKey + return rawCertNoKey, nil +} + +func (c *certificateV1) Marshal() ([]byte, error) { + rc := RawNebulaCertificate{ + Details: c.getRawDetails(), + Signature: c.signature, + } + + return proto.Marshal(&rc) +} + +func (c *certificateV1) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil +} + +func (c *certificateV1) MarshalJSON() ([]byte, error) { + return json.Marshal(c.marshalJSON()) +} + +func (c *certificateV1) marshalJSON() m { + fp, _ := c.Fingerprint() + return m{ + "version": Version1, + "details": m{ + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "publicKey": fmt.Sprintf("%x", c.details.publicKey), + "isCa": c.details.isCA, + "issuer": c.details.issuer, + "curve": c.details.curve.String(), + }, + "fingerprint": fp, + "signature": fmt.Sprintf("%x", c.Signature()), + } +} + +func (c *certificateV1) Copy() Certificate { + nc := &certificateV1{ + details: detailsV1{ + name: c.details.name, + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + publicKey: make([]byte, len(c.details.publicKey)), + isCA: c.details.isCA, + issuer: c.details.issuer, + curve: c.details.curve, + }, + signature: make([]byte, len(c.signature)), + } + + if c.details.groups != nil { + nc.details.groups = make([]string, len(c.details.groups)) + copy(nc.details.groups, c.details.groups) + } + + if c.details.networks != nil { + nc.details.networks = make([]netip.Prefix, len(c.details.networks)) + copy(nc.details.networks, c.details.networks) + } + + if c.details.unsafeNetworks != nil { + nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) + } + + copy(nc.signature, c.signature) + copy(nc.details.publicKey, c.details.publicKey) + + return nc +} + +func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV1{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + publicKey: t.PublicKey, + isCA: t.IsCA, + curve: t.Curve, + issuer: t.issuer, + } + + return c.validate() +} + +func (c *certificateV1) validate() error { + // Empty names are allowed + + if len(c.details.publicKey) == 0 { + return ErrInvalidPublicKey + } + + // Original v1 rules allowed multiple networks to be present but ignored all but the first one. + // Continue to allow this behavior + if !c.details.isCA && len(c.details.networks) == 0 { + return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network") + } + + for _, network := range c.details.networks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid network: %s", network) + } + + if network.Addr().Is6() { + return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network) + } + + if network.Addr().IsUnspecified() { + return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) + } + } + + for _, network := range c.details.unsafeNetworks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) + } + + if network.Addr().Is6() { + return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) + } + } + + // v1 doesn't bother with sort order or uniqueness of networks or unsafe networks. + // We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered + // unsafe networks would result in a different signature. + + return nil +} + +func (c *certificateV1) marshalForSigning() ([]byte, error) { + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return nil, err + } + return b, nil +} + +func (c *certificateV1) setSignature(b []byte) error { + if len(b) == 0 { + return ErrEmptySignature + } + c.signature = b + return nil +} + +// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert +// if the publicKey is provided here then it is not required to be present in `b` +func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rc RawNebulaCertificate + err := proto.Unmarshal(b, &rc) + if err != nil { + return nil, err + } + + if rc.Details == nil { + return nil, fmt.Errorf("encoded Details was nil") + } + + if len(rc.Details.Ips)%2 != 0 { + return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") + } + + if len(rc.Details.Subnets)%2 != 0 { + return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") + } + + nc := certificateV1{ + details: detailsV1{ + name: rc.Details.Name, + groups: make([]string, len(rc.Details.Groups)), + networks: make([]netip.Prefix, len(rc.Details.Ips)/2), + unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2), + notBefore: time.Unix(rc.Details.NotBefore, 0), + notAfter: time.Unix(rc.Details.NotAfter, 0), + publicKey: make([]byte, len(rc.Details.PublicKey)), + isCA: rc.Details.IsCA, + curve: rc.Details.Curve, + }, + signature: make([]byte, len(rc.Signature)), + } + + copy(nc.signature, rc.Signature) + copy(nc.details.groups, rc.Details.Groups) + nc.details.issuer = hex.EncodeToString(rc.Details.Issuer) + + if len(publicKey) > 0 { + nc.details.publicKey = publicKey + } + + copy(nc.details.publicKey, rc.Details.PublicKey) + + var ip netip.Addr + for i, rawIp := range rc.Details.Ips { + if i%2 == 0 { + ip = int2addr(rawIp) + } else { + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.networks[i/2] = netip.PrefixFrom(ip, ones) + } + } + + for i, rawIp := range rc.Details.Subnets { + if i%2 == 0 { + ip = int2addr(rawIp) + } else { + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones) + } + } + + err = nc.validate() + if err != nil { + return nil, err + } + + return &nc, nil +} + +func ip2int(ip []byte) uint32 { + if len(ip) == 16 { + return binary.BigEndian.Uint32(ip[12:16]) + } + return binary.BigEndian.Uint32(ip) +} + +func int2ip(nn uint32) net.IP { + ip := make(net.IP, net.IPv4len) + binary.BigEndian.PutUint32(ip, nn) + return ip +} + +func addr2int(addr netip.Addr) uint32 { + b := addr.Unmap().As4() + return binary.BigEndian.Uint32(b[:]) +} + +func int2addr(nn uint32) netip.Addr { + ip := [4]byte{} + binary.BigEndian.PutUint32(ip[:], nn) + return netip.AddrFrom4(ip).Unmap() +} diff --git a/cert/cert.pb.go b/cert/cert_v1.pb.go similarity index 62% rename from cert/cert.pb.go rename to cert/cert_v1.pb.go index 3570e07..32de1a0 100644 --- a/cert/cert.pb.go +++ b/cert/cert_v1.pb.go @@ -1,8 +1,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.30.0 +// protoc-gen-go v1.34.2 // protoc v3.21.5 -// source: cert.proto +// source: cert_v1.proto package cert @@ -50,11 +50,11 @@ func (x Curve) String() string { } func (Curve) Descriptor() protoreflect.EnumDescriptor { - return file_cert_proto_enumTypes[0].Descriptor() + return file_cert_v1_proto_enumTypes[0].Descriptor() } func (Curve) Type() protoreflect.EnumType { - return &file_cert_proto_enumTypes[0] + return &file_cert_v1_proto_enumTypes[0] } func (x Curve) Number() protoreflect.EnumNumber { @@ -63,7 +63,7 @@ func (x Curve) Number() protoreflect.EnumNumber { // Deprecated: Use Curve.Descriptor instead. func (Curve) EnumDescriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{0} + return file_cert_v1_proto_rawDescGZIP(), []int{0} } type RawNebulaCertificate struct { @@ -78,7 +78,7 @@ type RawNebulaCertificate struct { func (x *RawNebulaCertificate) Reset() { *x = RawNebulaCertificate{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[0] + mi := &file_cert_v1_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -91,7 +91,7 @@ func (x *RawNebulaCertificate) String() string { func (*RawNebulaCertificate) ProtoMessage() {} func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[0] + mi := &file_cert_v1_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -104,7 +104,7 @@ func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead. func (*RawNebulaCertificate) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{0} + return file_cert_v1_proto_rawDescGZIP(), []int{0} } func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails { @@ -143,7 +143,7 @@ type RawNebulaCertificateDetails struct { func (x *RawNebulaCertificateDetails) Reset() { *x = RawNebulaCertificateDetails{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[1] + mi := &file_cert_v1_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -156,7 +156,7 @@ func (x *RawNebulaCertificateDetails) String() string { func (*RawNebulaCertificateDetails) ProtoMessage() {} func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[1] + mi := &file_cert_v1_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -169,7 +169,7 @@ func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead. func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{1} + return file_cert_v1_proto_rawDescGZIP(), []int{1} } func (x *RawNebulaCertificateDetails) GetName() string { @@ -254,7 +254,7 @@ type RawNebulaEncryptedData struct { func (x *RawNebulaEncryptedData) Reset() { *x = RawNebulaEncryptedData{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[2] + mi := &file_cert_v1_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -267,7 +267,7 @@ func (x *RawNebulaEncryptedData) String() string { func (*RawNebulaEncryptedData) ProtoMessage() {} func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[2] + mi := &file_cert_v1_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -280,7 +280,7 @@ func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead. func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{2} + return file_cert_v1_proto_rawDescGZIP(), []int{2} } func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata { @@ -309,7 +309,7 @@ type RawNebulaEncryptionMetadata struct { func (x *RawNebulaEncryptionMetadata) Reset() { *x = RawNebulaEncryptionMetadata{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[3] + mi := &file_cert_v1_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -322,7 +322,7 @@ func (x *RawNebulaEncryptionMetadata) String() string { func (*RawNebulaEncryptionMetadata) ProtoMessage() {} func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[3] + mi := &file_cert_v1_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -335,7 +335,7 @@ func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead. func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{3} + return file_cert_v1_proto_rawDescGZIP(), []int{3} } func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string { @@ -367,7 +367,7 @@ type RawNebulaArgon2Parameters struct { func (x *RawNebulaArgon2Parameters) Reset() { *x = RawNebulaArgon2Parameters{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[4] + mi := &file_cert_v1_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -380,7 +380,7 @@ func (x *RawNebulaArgon2Parameters) String() string { func (*RawNebulaArgon2Parameters) ProtoMessage() {} func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[4] + mi := &file_cert_v1_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -393,7 +393,7 @@ func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead. func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{4} + return file_cert_v1_proto_rawDescGZIP(), []int{4} } func (x *RawNebulaArgon2Parameters) GetVersion() int32 { @@ -431,87 +431,87 @@ func (x *RawNebulaArgon2Parameters) GetSalt() []byte { return nil } -var File_cert_proto protoreflect.FileDescriptor +var File_cert_v1_proto protoreflect.FileDescriptor -var file_cert_proto_rawDesc = []byte{ - 0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65, - 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, - 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65, - 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, - 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, - 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07, - 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61, - 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, - 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, - 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53, - 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75, - 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a, - 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e, - 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e, - 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69, - 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c, - 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73, - 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, - 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63, - 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, - 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, - 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, - 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, - 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, - 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, - 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, - 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, - 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, - 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, - 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, - 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, - 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, - 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, - 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, - 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65, - 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00, - 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, - 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, +var file_cert_v1_proto_rawDesc = []byte{ + 0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, + 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, + 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, + 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, + 0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, + 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, + 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, + 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, + 0x0a, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, + 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, + 0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, + 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, + 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, + 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, + 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, + 0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, + 0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, + 0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, + 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, + 0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, + 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, + 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, + 0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, + 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, + 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, + 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, + 0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, + 0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, + 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, + 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, + 0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, + 0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, + 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, + 0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, + 0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, + 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, + 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( - file_cert_proto_rawDescOnce sync.Once - file_cert_proto_rawDescData = file_cert_proto_rawDesc + file_cert_v1_proto_rawDescOnce sync.Once + file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc ) -func file_cert_proto_rawDescGZIP() []byte { - file_cert_proto_rawDescOnce.Do(func() { - file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData) +func file_cert_v1_proto_rawDescGZIP() []byte { + file_cert_v1_proto_rawDescOnce.Do(func() { + file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData) }) - return file_cert_proto_rawDescData + return file_cert_v1_proto_rawDescData } -var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5) -var file_cert_proto_goTypes = []interface{}{ +var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_cert_v1_proto_goTypes = []any{ (Curve)(0), // 0: cert.Curve (*RawNebulaCertificate)(nil), // 1: cert.RawNebulaCertificate (*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails @@ -519,7 +519,7 @@ var file_cert_proto_goTypes = []interface{}{ (*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata (*RawNebulaArgon2Parameters)(nil), // 5: cert.RawNebulaArgon2Parameters } -var file_cert_proto_depIdxs = []int32{ +var file_cert_v1_proto_depIdxs = []int32{ 2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails 0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve 4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata @@ -531,13 +531,13 @@ var file_cert_proto_depIdxs = []int32{ 0, // [0:4] is the sub-list for field type_name } -func init() { file_cert_proto_init() } -func file_cert_proto_init() { - if File_cert_proto != nil { +func init() { file_cert_v1_proto_init() } +func file_cert_v1_proto_init() { + if File_cert_v1_proto != nil { return } if !protoimpl.UnsafeEnabled { - file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaCertificate); i { case 0: return &v.state @@ -549,7 +549,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaCertificateDetails); i { case 0: return &v.state @@ -561,7 +561,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaEncryptedData); i { case 0: return &v.state @@ -573,7 +573,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaEncryptionMetadata); i { case 0: return &v.state @@ -585,7 +585,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaArgon2Parameters); i { case 0: return &v.state @@ -602,19 +602,19 @@ func file_cert_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_cert_proto_rawDesc, + RawDescriptor: file_cert_v1_proto_rawDesc, NumEnums: 1, NumMessages: 5, NumExtensions: 0, NumServices: 0, }, - GoTypes: file_cert_proto_goTypes, - DependencyIndexes: file_cert_proto_depIdxs, - EnumInfos: file_cert_proto_enumTypes, - MessageInfos: file_cert_proto_msgTypes, + GoTypes: file_cert_v1_proto_goTypes, + DependencyIndexes: file_cert_v1_proto_depIdxs, + EnumInfos: file_cert_v1_proto_enumTypes, + MessageInfos: file_cert_v1_proto_msgTypes, }.Build() - File_cert_proto = out.File - file_cert_proto_rawDesc = nil - file_cert_proto_goTypes = nil - file_cert_proto_depIdxs = nil + File_cert_v1_proto = out.File + file_cert_v1_proto_rawDesc = nil + file_cert_v1_proto_goTypes = nil + file_cert_v1_proto_depIdxs = nil } diff --git a/cert/cert.proto b/cert/cert_v1.proto similarity index 100% rename from cert/cert.proto rename to cert/cert_v1.proto diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go new file mode 100644 index 0000000..c687172 --- /dev/null +++ b/cert/cert_v1_test.go @@ -0,0 +1,218 @@ +package cert + +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func TestCertificateV1_Marshal(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + require.NoError(t, err) + //t.Log("Cert size:", len(b)) + + nc2, err := unmarshalCertificateV1(b, nil) + require.NoError(t, err) + + assert.Equal(t, Version1, nc.Version()) + assert.Equal(t, Curve_CURVE25519, nc.Curve()) + assert.Equal(t, nc.Signature(), nc2.Signature()) + assert.Equal(t, nc.Name(), nc2.Name()) + assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) + assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) + assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) + assert.Equal(t, nc.IsCA(), nc2.IsCA()) + + assert.Equal(t, nc.Networks(), nc2.Networks()) + assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + assert.Equal(t, nc.Groups(), nc2.Groups()) +} + +func TestCertificateV1_Expired(t *testing.T) { + nc := certificateV1{ + details: detailsV1{ + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestCertificateV1_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.MarshalJSON() + require.NoError(t, err) + assert.JSONEq( + t, + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", + string(b), + ) +} + +func TestCertificateV1_VerifyPrivateKey(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + require.NoError(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + require.NoError(t, err) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) + require.Error(t, err) + + c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + require.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_CURVE25519, curve) + err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) + require.NoError(t, err) + + _, priv2 := X25519Keypair() + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + require.Error(t, err) +} + +func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_P256, caKey) + require.NoError(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + require.NoError(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + require.Error(t, err) + + c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + require.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_P256, curve) + err = c.VerifyPrivateKey(Curve_P256, rawPriv) + require.NoError(t, err) + + _, priv2 := P256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + require.Error(t, err) +} + +// Ensure that upgrading the protobuf library does not change how certificates +// are marshalled, since this would break signature verification +func TestMarshalingCertificateV1Consistency(t *testing.T) { + before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC) + after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + require.NoError(t, err) + assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) + + b, err = proto.Marshal(nc.getRawDetails()) + require.NoError(t, err) + assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) +} + +func TestCertificateV1_Copy(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + cc := c.Copy() + test.AssertDeepCopyEqual(t, c, cc) +} + +func TestUnmarshalCertificateV1(t *testing.T) { + // Test that we don't panic with an invalid certificate (#332) + data := []byte("\x98\x00\x00") + _, err := unmarshalCertificateV1(data, nil) + require.EqualError(t, err, "encoded Details was nil") +} + +func appendByteSlices(b ...[]byte) []byte { + retSlice := []byte{} + for _, v := range b { + retSlice = append(retSlice, v...) + } + return retSlice +} + +func mustParsePrefixUnmapped(s string) netip.Prefix { + prefix := netip.MustParsePrefix(s) + return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits()) +} diff --git a/cert/cert_v2.asn1 b/cert/cert_v2.asn1 new file mode 100644 index 0000000..f863133 --- /dev/null +++ b/cert/cert_v2.asn1 @@ -0,0 +1,37 @@ +Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN + +Name ::= UTF8String (SIZE (1..253)) +Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum +Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length +Curve ::= ENUMERATED { + curve25519 (0), + p256 (1) +} + +-- The maximum size of a certificate must not exceed 65536 bytes +Certificate ::= SEQUENCE { + details OCTET STRING, + curve Curve DEFAULT curve25519, + publicKey OCTET STRING, + -- signature(details + curve + publicKey) using the appropriate method for curve + signature OCTET STRING +} + +Details ::= SEQUENCE { + name Name, + + -- At least 1 ipv4 or ipv6 address must be present if isCA is false + networks SEQUENCE OF Network OPTIONAL, + unsafeNetworks SEQUENCE OF Network OPTIONAL, + groups SEQUENCE OF Name OPTIONAL, + isCA BOOLEAN DEFAULT false, + notBefore Time, + notAfter Time, + + -- issuer is only required if isCA is false, if isCA is true then it must not be present + issuer OCTET STRING OPTIONAL, + ... + -- New fields can be added below here +} + +END \ No newline at end of file diff --git a/cert/cert_v2.go b/cert/cert_v2.go new file mode 100644 index 0000000..322463e --- /dev/null +++ b/cert/cert_v2.go @@ -0,0 +1,730 @@ +package cert + +import ( + "bytes" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "net/netip" + "slices" + "time" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" + "golang.org/x/crypto/curve25519" +) + +const ( + classConstructed = 0x20 + classContextSpecific = 0x80 + + TagCertDetails = 0 | classConstructed | classContextSpecific + TagCertCurve = 1 | classContextSpecific + TagCertPublicKey = 2 | classContextSpecific + TagCertSignature = 3 | classContextSpecific + + TagDetailsName = 0 | classContextSpecific + TagDetailsNetworks = 1 | classConstructed | classContextSpecific + TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific + TagDetailsGroups = 3 | classConstructed | classContextSpecific + TagDetailsIsCA = 4 | classContextSpecific + TagDetailsNotBefore = 5 | classContextSpecific + TagDetailsNotAfter = 6 | classContextSpecific + TagDetailsIssuer = 7 | classContextSpecific +) + +const ( + // MaxCertificateSize is the maximum length a valid certificate can be + MaxCertificateSize = 65536 + + // MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems + MaxNameLength = 253 + + // MaxNetworkLength is the maximum length a network value can be. + // 16 bytes for an ipv6 address + 1 byte for the prefix length + MaxNetworkLength = 17 +) + +type certificateV2 struct { + details detailsV2 + + // RawDetails contains the entire asn.1 DER encoded Details struct + // This is to benefit forwards compatibility in signature checking. + // signature(RawDetails + Curve + PublicKey) == Signature + rawDetails []byte + curve Curve + publicKey []byte + signature []byte +} + +type detailsV2 struct { + name string + networks []netip.Prefix // MUST BE SORTED + unsafeNetworks []netip.Prefix // MUST BE SORTED + groups []string + isCA bool + notBefore time.Time + notAfter time.Time + issuer string +} + +func (c *certificateV2) Version() Version { + return Version2 +} + +func (c *certificateV2) Curve() Curve { + return c.curve +} + +func (c *certificateV2) Groups() []string { + return c.details.groups +} + +func (c *certificateV2) IsCA() bool { + return c.details.isCA +} + +func (c *certificateV2) Issuer() string { + return c.details.issuer +} + +func (c *certificateV2) Name() string { + return c.details.name +} + +func (c *certificateV2) Networks() []netip.Prefix { + return c.details.networks +} + +func (c *certificateV2) NotAfter() time.Time { + return c.details.notAfter +} + +func (c *certificateV2) NotBefore() time.Time { + return c.details.notBefore +} + +func (c *certificateV2) PublicKey() []byte { + return c.publicKey +} + +func (c *certificateV2) Signature() []byte { + return c.signature +} + +func (c *certificateV2) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks +} + +func (c *certificateV2) Fingerprint() (string, error) { + if len(c.rawDetails) == 0 { + return "", ErrMissingDetails + } + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature) + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (c *certificateV2) CheckSignature(key []byte) bool { + if len(c.rawDetails) == 0 { + return false + } + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + + switch c.curve { + case Curve_CURVE25519: + return ed25519.Verify(key, b, c.signature) + case Curve_P256: + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) + default: + return false + } +} + +func (c *certificateV2) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) +} + +func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.curve { + return ErrPublicPrivateCurveMismatch + } + if c.details.isCA { + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return ErrInvalidPrivateKey + } + + if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return ErrPublicPrivateKeyMismatch + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return ErrInvalidPrivateKey + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch + } + default: + return fmt.Errorf("invalid curve: %s", curve) + } + return nil + } + + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return ErrInvalidPrivateKey + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return ErrInvalidPrivateKey + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) + } + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch + } + + return nil +} + +func (c *certificateV2) String() string { + mb, err := c.marshalJSON() + if err != nil { + return fmt.Sprintf("", err) + } + + b, err := json.MarshalIndent(mb, "", "\t") + if err != nil { + return fmt.Sprintf("", err) + } + return string(b) +} + +func (c *certificateV2) MarshalForHandshakes() ([]byte, error) { + if c.rawDetails == nil { + return nil, ErrEmptyRawDetails + } + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Skipping the curve and public key since those come across in a different part of the handshake + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) Marshal() ([]byte, error) { + if c.rawDetails == nil { + return nil, ErrEmptyRawDetails + } + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Add the curve only if its not the default value + if c.curve != Curve_CURVE25519 { + b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) { + b.AddBytes([]byte{byte(c.curve)}) + }) + } + + // Add the public key if it is not empty + if c.publicKey != nil { + b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) { + b.AddBytes(c.publicKey) + }) + } + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil +} + +func (c *certificateV2) MarshalJSON() ([]byte, error) { + b, err := c.marshalJSON() + if err != nil { + return nil, err + } + return json.Marshal(b) +} + +func (c *certificateV2) marshalJSON() (m, error) { + fp, err := c.Fingerprint() + if err != nil { + return nil, err + } + + return m{ + "details": m{ + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "isCa": c.details.isCA, + "issuer": c.details.issuer, + }, + "version": Version2, + "publicKey": fmt.Sprintf("%x", c.publicKey), + "curve": c.curve.String(), + "fingerprint": fp, + "signature": fmt.Sprintf("%x", c.Signature()), + }, nil +} + +func (c *certificateV2) Copy() Certificate { + nc := &certificateV2{ + details: detailsV2{ + name: c.details.name, + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + isCA: c.details.isCA, + issuer: c.details.issuer, + }, + curve: c.curve, + publicKey: make([]byte, len(c.publicKey)), + signature: make([]byte, len(c.signature)), + rawDetails: make([]byte, len(c.rawDetails)), + } + + if c.details.groups != nil { + nc.details.groups = make([]string, len(c.details.groups)) + copy(nc.details.groups, c.details.groups) + } + + if c.details.networks != nil { + nc.details.networks = make([]netip.Prefix, len(c.details.networks)) + copy(nc.details.networks, c.details.networks) + } + + if c.details.unsafeNetworks != nil { + nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) + } + + copy(nc.rawDetails, c.rawDetails) + copy(nc.signature, c.signature) + copy(nc.publicKey, c.publicKey) + + return nc +} + +func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV2{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + isCA: t.IsCA, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + issuer: t.issuer, + } + c.curve = t.Curve + c.publicKey = t.PublicKey + return c.validate() +} + +func (c *certificateV2) validate() error { + // Empty names are allowed + + if len(c.publicKey) == 0 { + return ErrInvalidPublicKey + } + + if !c.details.isCA && len(c.details.networks) == 0 { + return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network") + } + + hasV4Networks := false + hasV6Networks := false + for _, network := range c.details.networks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid network: %s", network) + } + + if network.Addr().IsUnspecified() { + return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) + } + + if network.Addr().Is4In6() { + return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network) + } + + hasV4Networks = hasV4Networks || network.Addr().Is4() + hasV6Networks = hasV6Networks || network.Addr().Is6() + } + + slices.SortFunc(c.details.networks, comparePrefix) + err := findDuplicatePrefix(c.details.networks) + if err != nil { + return err + } + + for _, network := range c.details.unsafeNetworks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) + } + + if !c.details.isCA { + if network.Addr().Is6() { + if !hasV6Networks { + return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network) + } + } else if network.Addr().Is4() { + if !hasV4Networks { + return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) + } + } + } + } + + slices.SortFunc(c.details.unsafeNetworks, comparePrefix) + err = findDuplicatePrefix(c.details.unsafeNetworks) + if err != nil { + return err + } + + return nil +} + +func (c *certificateV2) marshalForSigning() ([]byte, error) { + d, err := c.details.Marshal() + if err != nil { + return nil, fmt.Errorf("marshalling certificate details failed: %w", err) + } + c.rawDetails = d + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + return b, nil +} + +func (c *certificateV2) setSignature(b []byte) error { + if len(b) == 0 { + return ErrEmptySignature + } + c.signature = b + return nil +} + +func (d *detailsV2) Marshal() ([]byte, error) { + var b cryptobyte.Builder + var err error + + // Details are a structure + b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) { + + // Add the name + b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(d.name)) + }) + + // Add the networks if any exist + if len(d.networks) > 0 { + b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.networks { + sb, innerErr := n.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal network: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add the unsafe networks if any exist + if len(d.unsafeNetworks) > 0 { + b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.unsafeNetworks { + sb, innerErr := n.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add groups if any exist + if len(d.groups) > 0 { + b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) { + for _, group := range d.groups { + b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(group)) + }) + } + }) + } + + // Add IsCA only if true + if d.isCA { + b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) { + b.AddUint8(0xff) + }) + } + + // Add not before + b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore) + + // Add not after + b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter) + + // Add the issuer if present + if d.issuer != "" { + issuerBytes, innerErr := hex.DecodeString(d.issuer) + if innerErr != nil { + err = fmt.Errorf("failed to decode issuer: %w", innerErr) + return + } + b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) { + b.AddBytes(issuerBytes) + }) + } + }) + + if err != nil { + return nil, err + } + + return b.Bytes() +} + +func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) { + l := len(b) + if l == 0 || l > MaxCertificateSize { + return nil, ErrBadFormat + } + + input := cryptobyte.String(b) + // Open the envelope + if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() { + return nil, ErrBadFormat + } + + // Grab the cert details, we need to preserve the tag and length + var rawDetails cryptobyte.String + if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() { + return nil, ErrBadFormat + } + + //Maybe grab the curve + var rawCurve byte + if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) { + return nil, ErrBadFormat + } + curve = Curve(rawCurve) + + // Maybe grab the public key + var rawPublicKey cryptobyte.String + if len(publicKey) > 0 { + rawPublicKey = publicKey + } else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) { + return nil, ErrBadFormat + } + + if len(rawPublicKey) == 0 { + return nil, ErrBadFormat + } + + // Grab the signature + var rawSignature cryptobyte.String + if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() { + return nil, ErrBadFormat + } + + // Finally unmarshal the details + details, err := unmarshalDetails(rawDetails) + if err != nil { + return nil, err + } + + c := &certificateV2{ + details: details, + rawDetails: rawDetails, + curve: curve, + publicKey: rawPublicKey, + signature: rawSignature, + } + + err = c.validate() + if err != nil { + return nil, err + } + + return c, nil +} + +func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { + // Open the envelope + if !b.ReadASN1(&b, TagCertDetails) || b.Empty() { + return detailsV2{}, ErrBadFormat + } + + // Read the name + var name cryptobyte.String + if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength { + return detailsV2{}, ErrBadFormat + } + + // Read the network addresses + var subString cryptobyte.String + var found bool + + if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) { + return detailsV2{}, ErrBadFormat + } + + var networks []netip.Prefix + var val cryptobyte.String + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { + return detailsV2{}, ErrBadFormat + } + + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + networks = append(networks, n) + } + } + + // Read out any unsafe networks + if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) { + return detailsV2{}, ErrBadFormat + } + + var unsafeNetworks []netip.Prefix + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { + return detailsV2{}, ErrBadFormat + } + + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + unsafeNetworks = append(unsafeNetworks, n) + } + } + + // Read out any groups + if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) { + return detailsV2{}, ErrBadFormat + } + + var groups []string + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() { + return detailsV2{}, ErrBadFormat + } + groups = append(groups, string(val)) + } + } + + // Read out IsCA + var isCa bool + if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) { + return detailsV2{}, ErrBadFormat + } + + // Read not before and not after + var notBefore int64 + if !b.ReadASN1Int64WithTag(¬Before, TagDetailsNotBefore) { + return detailsV2{}, ErrBadFormat + } + + var notAfter int64 + if !b.ReadASN1Int64WithTag(¬After, TagDetailsNotAfter) { + return detailsV2{}, ErrBadFormat + } + + // Read issuer + var issuer cryptobyte.String + if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) { + return detailsV2{}, ErrBadFormat + } + + return detailsV2{ + name: string(name), + networks: networks, + unsafeNetworks: unsafeNetworks, + groups: groups, + isCA: isCa, + notBefore: time.Unix(notBefore, 0), + notAfter: time.Unix(notAfter, 0), + issuer: hex.EncodeToString(issuer), + }, nil +} diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go new file mode 100644 index 0000000..c84f8c9 --- /dev/null +++ b/cert/cert_v2_test.go @@ -0,0 +1,267 @@ +package cert + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "net/netip" + "slices" + "testing" + "time" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCertificateV2_Marshal(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcdef1234567890abcdef", + }, + signature: []byte("1234567890abcdef1234567890abcdef"), + publicKey: pubKey, + } + + db, err := nc.details.Marshal() + require.NoError(t, err) + nc.rawDetails = db + + b, err := nc.Marshal() + require.NoError(t, err) + //t.Log("Cert size:", len(b)) + + nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) + require.NoError(t, err) + + assert.Equal(t, Version2, nc.Version()) + assert.Equal(t, Curve_CURVE25519, nc.Curve()) + assert.Equal(t, nc.Signature(), nc2.Signature()) + assert.Equal(t, nc.Name(), nc2.Name()) + assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) + assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) + assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) + assert.Equal(t, nc.IsCA(), nc2.IsCA()) + assert.Equal(t, nc.Issuer(), nc2.Issuer()) + + // unmarshalling will sort networks and unsafeNetworks, we need to do the same + // but first make sure it fails + assert.NotEqual(t, nc.Networks(), nc2.Networks()) + assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + slices.SortFunc(nc.details.networks, comparePrefix) + slices.SortFunc(nc.details.unsafeNetworks, comparePrefix) + + assert.Equal(t, nc.Networks(), nc2.Networks()) + assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + assert.Equal(t, nc.Groups(), nc2.Groups()) +} + +func TestCertificateV2_Expired(t *testing.T) { + nc := certificateV2{ + details: detailsV2{ + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestCertificateV2_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedf1234567890abcedf") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + isCA: false, + issuer: "1234567890abcedf1234567890abcedf", + }, + publicKey: pubKey, + signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"), + } + + b, err := nc.MarshalJSON() + require.ErrorIs(t, err, ErrMissingDetails) + + rd, err := nc.details.Marshal() + require.NoError(t, err) + + nc.rawDetails = rd + b, err = nc.MarshalJSON() + require.NoError(t, err) + assert.JSONEq( + t, + "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", + string(b), + ) +} + +func TestCertificateV2_VerifyPrivateKey(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + require.NoError(t, err) + + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) + require.ErrorIs(t, err, ErrInvalidPrivateKey) + + _, caKey2, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) + require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + + c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + require.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_CURVE25519, curve) + err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) + require.NoError(t, err) + + _, priv2 := X25519Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) + + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) + require.ErrorIs(t, err, ErrInvalidPrivateKey) + + ac, ok := c.(*certificateV2) + require.True(t, ok) + ac.curve = Curve(99) + err = c.VerifyPrivateKey(Curve(99), priv2) + require.EqualError(t, err, "invalid curve: 99") + + ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + require.NoError(t, err) + + err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) + require.ErrorIs(t, err, ErrInvalidPrivateKey) + + c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) + + err = c.VerifyPrivateKey(Curve_P256, priv[:16]) + require.ErrorIs(t, err, ErrInvalidPrivateKey) + + err = c.VerifyPrivateKey(Curve_P256, priv) + require.ErrorIs(t, err, ErrInvalidPrivateKey) + + aCa, ok := ca2.(*certificateV2) + require.True(t, ok) + aCa.curve = Curve(99) + err = aCa.VerifyPrivateKey(Curve(99), priv2) + require.EqualError(t, err, "invalid curve: 99") + +} + +func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_P256, caKey) + require.NoError(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + require.NoError(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + require.Error(t, err) + + c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + require.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_P256, curve) + err = c.VerifyPrivateKey(Curve_P256, rawPriv) + require.NoError(t, err) + + _, priv2 := P256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + require.Error(t, err) +} + +func TestCertificateV2_Copy(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + cc := c.Copy() + test.AssertDeepCopyEqual(t, c, cc) +} + +func TestUnmarshalCertificateV2(t *testing.T) { + data := []byte("\x98\x00\x00") + _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) + require.EqualError(t, err, "bad wire format") +} + +func TestCertificateV2_marshalForSigningStability(t *testing.T) { + before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC) + after := before.Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcdef1234567890abcdef", + }, + signature: []byte("1234567890abcdef1234567890abcdef"), + publicKey: pubKey, + } + + const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef" + expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr) + require.NoError(t, err) + + db, err := nc.details.Marshal() + require.NoError(t, err) + assert.Equal(t, expectedRawDetails, db) + + expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162") + b, err := nc.marshalForSigning() + require.NoError(t, err) + assert.Equal(t, expectedForSigning, b) +} diff --git a/cert/crypto.go b/cert/crypto.go index 3558e1a..4c236ae 100644 --- a/cert/crypto.go +++ b/cert/crypto.go @@ -3,14 +3,28 @@ package cert import ( "crypto/aes" "crypto/cipher" + "crypto/ed25519" "crypto/rand" + "encoding/pem" "fmt" "io" + "math" "golang.org/x/crypto/argon2" + "google.golang.org/protobuf/proto" ) -// KDF factors +type NebulaEncryptedData struct { + EncryptionMetadata NebulaEncryptionMetadata + Ciphertext []byte +} + +type NebulaEncryptionMetadata struct { + EncryptionAlgorithm string + Argon2Parameters Argon2Parameters +} + +// Argon2Parameters KDF factors type Argon2Parameters struct { version rune Memory uint32 // KiB @@ -19,7 +33,7 @@ type Argon2Parameters struct { salt []byte } -// Returns a new Argon2Parameters object with current version set +// NewArgon2Parameters Returns a new Argon2Parameters object with current version set func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters { return &Argon2Parameters{ version: argon2.Version, @@ -141,3 +155,146 @@ func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) { return blob[:nonceSize], blob[nonceSize:], nil } + +// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key +func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { + ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) + if err != nil { + return nil, err + } + + b, err = proto.Marshal(&RawNebulaEncryptedData{ + EncryptionMetadata: &RawNebulaEncryptionMetadata{ + EncryptionAlgorithm: "AES-256-GCM", + Argon2Parameters: &RawNebulaArgon2Parameters{ + Version: kdfParams.version, + Memory: kdfParams.Memory, + Parallelism: uint32(kdfParams.Parallelism), + Iterations: kdfParams.Iterations, + Salt: kdfParams.salt, + }, + }, + Ciphertext: ciphertext, + }) + if err != nil { + return nil, err + } + + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil + default: + return nil, fmt.Errorf("invalid curve: %v", curve) + } +} + +// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its +// protobuf-generated struct. +func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rned RawNebulaEncryptedData + err := proto.Unmarshal(b, &rned) + if err != nil { + return nil, err + } + + if rned.EncryptionMetadata == nil { + return nil, fmt.Errorf("encoded EncryptionMetadata was nil") + } + + if rned.EncryptionMetadata.Argon2Parameters == nil { + return nil, fmt.Errorf("encoded Argon2Parameters was nil") + } + + params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) + if err != nil { + return nil, err + } + + ned := NebulaEncryptedData{ + EncryptionMetadata: NebulaEncryptionMetadata{ + EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, + Argon2Parameters: *params, + }, + Ciphertext: rned.Ciphertext, + } + + return &ned, nil +} + +func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { + if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { + return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) + } + if params.Memory <= 0 || params.Memory > math.MaxUint32 { + return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { + return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { + return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return &Argon2Parameters{ + version: params.Version, + Memory: params.Memory, + Parallelism: uint8(params.Parallelism), + Iterations: params.Iterations, + salt: params.Salt, + }, nil + +} + +// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with +// the given passphrase, returning any other bytes b or an error on failure +func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) { + var curve Curve + + k, r := pem.Decode(b) + if k == nil { + return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + + switch k.Type { + case EncryptedEd25519PrivateKeyBanner: + curve = Curve_CURVE25519 + case EncryptedECDSAP256PrivateKeyBanner: + curve = Curve_P256 + default: + return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + } + + ned, err := UnmarshalNebulaEncryptedData(k.Bytes) + if err != nil { + return curve, nil, r, err + } + + var bytes []byte + switch ned.EncryptionMetadata.EncryptionAlgorithm { + case "AES-256-GCM": + bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) + if err != nil { + return curve, nil, r, err + } + default: + return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) + } + + switch curve { + case Curve_CURVE25519: + if len(bytes) != ed25519.PrivateKeySize { + return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize) + } + case Curve_P256: + if len(bytes) != 32 { + return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") + } + } + + return curve, bytes, r, nil +} diff --git a/cert/crypto_test.go b/cert/crypto_test.go index c2e61df..6358ba6 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -4,22 +4,110 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/argon2" ) func TestNewArgon2Parameters(t *testing.T) { p := NewArgon2Parameters(64*1024, 4, 3) - assert.EqualValues(t, &Argon2Parameters{ + assert.Equal(t, &Argon2Parameters{ version: argon2.Version, Memory: 64 * 1024, Parallelism: 4, Iterations: 3, }, p) p = NewArgon2Parameters(2*1024*1024, 2, 1) - assert.EqualValues(t, &Argon2Parameters{ + assert.Equal(t, &Argon2Parameters{ version: argon2.Version, Memory: 2 * 1024 * 1024, Parallelism: 2, Iterations: 1, }, p) } + +func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) { + passphrase := []byte("DO NOT USE THIS KEY") + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + shortKey := []byte(`# A key which, once decrypted, is too short +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 +k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe +GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs +rQr3bdH3Oy/WiYU= +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner (not encrypted) +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG +XgLvodMXZJuaFPssp+WwtA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + + keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) + + // Success test case + curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) + require.NoError(t, err) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + + // Fail due to invalid banner + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + require.EqualError(t, err, "input did not contain a valid PEM encoded block") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to invalid passphrase + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) + require.EqualError(t, err, "invalid passphrase or corrupt private key") + assert.Nil(t, k) + assert.Equal(t, []byte{}, rest) +} + +func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { + // Having proved that decryption works correctly above, we can test the + // encryption function produces a value which can be decrypted + passphrase := []byte("passphrase") + bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + kdfParams := NewArgon2Parameters(64*1024, 4, 3) + key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) + require.NoError(t, err) + + // Verify the "key" can be decrypted successfully + curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) + assert.Len(t, k, 64) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, []byte{}, rest) + require.NoError(t, err) + + // EncryptAndMarshalEd25519PrivateKey does not create any errors itself +} diff --git a/cert/errors.go b/cert/errors.go index 05b42d1..4bbc023 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -2,13 +2,48 @@ package cert import ( "errors" + "fmt" ) var ( - ErrRootExpired = errors.New("root certificate is expired") - ErrExpired = errors.New("certificate is expired") - ErrNotCA = errors.New("certificate is not a CA") - ErrNotSelfSigned = errors.New("certificate is not self-signed") - ErrBlockListed = errors.New("certificate is in the block list") - ErrSignatureMismatch = errors.New("certificate signature did not match") + ErrBadFormat = errors.New("bad wire format") + ErrRootExpired = errors.New("root certificate is expired") + ErrExpired = errors.New("certificate is expired") + ErrNotCA = errors.New("certificate is not a CA") + ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrBlockListed = errors.New("certificate is in the block list") + ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") + ErrSignatureMismatch = errors.New("certificate signature did not match") + ErrInvalidPublicKey = errors.New("invalid public key") + ErrInvalidPrivateKey = errors.New("invalid private key") + ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve") + ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") + ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") + ErrCaNotFound = errors.New("could not find ca for the certificate") + + ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") + ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") + ErrInvalidPEMX25519PublicKeyBanner = errors.New("bytes did not contain a proper X25519 public key banner") + ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") + ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") + ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") + + ErrNoPeerStaticKey = errors.New("no peer static key was present") + ErrNoPayload = errors.New("provided payload was empty") + + ErrMissingDetails = errors.New("certificate did not contain details") + ErrEmptySignature = errors.New("empty signature") + ErrEmptyRawDetails = errors.New("empty rawDetails not allowed") ) + +type ErrInvalidCertificateProperties struct { + str string +} + +func NewErrInvalidCertificateProperties(format string, a ...any) error { + return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)} +} + +func (e *ErrInvalidCertificateProperties) Error() string { + return e.str +} diff --git a/cert/helper_test.go b/cert/helper_test.go new file mode 100644 index 0000000..1b72a0f --- /dev/null +++ b/cert/helper_test.go @@ -0,0 +1,141 @@ +package cert + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "io" + "net/netip" + "time" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will create a new ca certificate +func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { + var err error + var pub, priv []byte + + switch curve { + case Curve_CURVE25519: + pub, priv, err = ed25519.GenerateKey(rand.Reader) + case Curve_P256: + privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + priv = privk.D.FillBytes(make([]byte, 32)) + default: + // There is no default to allow the underlying lib to respond with an error + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + t := &TBSCertificate{ + Curve: curve, + Version: version, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, curve, priv) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. +// Expiry times are defaulted if you do not pass them in +func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + if len(networks) == 0 { + networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} + } + + var pub, priv []byte + switch curve { + case Curve_CURVE25519: + pub, priv = X25519Keypair() + case Curve_P256: + pub, priv = P256Keypair() + default: + panic("unknown curve") + } + + nc := &TBSCertificate{ + Version: v, + Curve: curve, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem +} + +func X25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} + +func P256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cert/pem.go b/cert/pem.go new file mode 100644 index 0000000..7ad28d1 --- /dev/null +++ b/cert/pem.go @@ -0,0 +1,161 @@ +package cert + +import ( + "encoding/pem" + "fmt" + + "golang.org/x/crypto/ed25519" +) + +const ( + CertificateBanner = "NEBULA CERTIFICATE" + CertificateV2Banner = "NEBULA CERTIFICATE V2" + X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" + X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" + EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" + Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" + Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" + + P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" + P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" + EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" + ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" +) + +// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed +// data or an error on failure +func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { + p, r := pem.Decode(b) + if p == nil { + return nil, r, ErrInvalidPEMBlock + } + + var c Certificate + var err error + + switch p.Type { + // Implementations must validate the resulting certificate contains valid information + case CertificateBanner: + c, err = unmarshalCertificateV1(p.Bytes, nil) + case CertificateV2Banner: + c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) + default: + return nil, r, ErrInvalidPEMCertificateBanner + } + + if err != nil { + return nil, r, err + } + + return c, r, nil + +} + +func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) + default: + return nil + } +} + +func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var expectedLen int + var curve Curve + switch k.Type { + case X25519PublicKeyBanner, Ed25519PublicKeyBanner: + expectedLen = 32 + curve = Curve_CURVE25519 + case P256PublicKeyBanner: + // Uncompressed + expectedLen = 65 + curve = Curve_P256 + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner") + } + if len(k.Bytes) != expectedLen { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve) + } + return k.Bytes, r, curve, nil +} + +func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b}) + default: + return nil + } +} + +func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b}) + default: + return nil + } +} + +// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non +// consumed data or an error on failure +func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var expectedLen int + var curve Curve + switch k.Type { + case X25519PrivateKeyBanner: + expectedLen = 32 + curve = Curve_CURVE25519 + case P256PrivateKeyBanner: + expectedLen = 32 + curve = Curve_P256 + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner") + } + if len(k.Bytes) != expectedLen { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve) + } + return k.Bytes, r, curve, nil +} + +func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var curve Curve + switch k.Type { + case EncryptedEd25519PrivateKeyBanner: + return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted + case EncryptedECDSAP256PrivateKeyBanner: + return nil, nil, Curve_P256, ErrPrivateKeyEncrypted + case Ed25519PrivateKeyBanner: + curve = Curve_CURVE25519 + if len(k.Bytes) != ed25519.PrivateKeySize { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize) + } + case ECDSAP256PrivateKeyBanner: + curve = Curve_P256 + if len(k.Bytes) != 32 { + return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") + } + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner") + } + return k.Bytes, r, curve, nil +} diff --git a/cert/pem_test.go b/cert/pem_test.go new file mode 100644 index 0000000..6e49249 --- /dev/null +++ b/cert/pem_test.go @@ -0,0 +1,293 @@ +package cert + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnmarshalCertificateFromPEM(t *testing.T) { + goodCert := []byte(` +# A good cert +-----BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NEBULA CERTIFICATE----- +`) + badBanner := []byte(`# A bad banner +-----BEGIN NOT A NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NOT A NEBULA CERTIFICATE----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-END NEBULA CERTIFICATE----`) + + certBundle := appendByteSlices(goodCert, badBanner, invalidPem) + + // Success test case + cert, rest, err := UnmarshalCertificateFromPEM(certBundle) + assert.NotNil(t, cert) + assert.Equal(t, rest, append(badBanner, invalidPem...)) + require.NoError(t, err) + + // Fail due to invalid banner. + cert, rest, err = UnmarshalCertificateFromPEM(rest) + assert.Nil(t, cert) + assert.Equal(t, rest, invalidPem) + require.EqualError(t, err, "bytes did not contain a proper certificate banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + cert, rest, err = UnmarshalCertificateFromPEM(rest) + assert.Nil(t, cert) + assert.Equal(t, rest, invalidPem) + require.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + privP256Key := []byte(`# A good key +-----BEGIN NEBULA ECDSA P256 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA ECDSA P256 PRIVATE KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NOT A NEBULA PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-END NEBULA ED25519 PRIVATE KEY-----`) + + keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + require.NoError(t, err) + + // Success test case + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + require.NoError(t, err) + + // Fail due to short key + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + require.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { + privKey := []byte(`# A good key +-----BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA X25519 PRIVATE KEY----- +`) + privP256Key := []byte(`# A good key +-----BEGIN NEBULA P256 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PRIVATE KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA X25519 PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA X25519 PRIVATE KEY-----`) + + keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + require.NoError(t, err) + + // Success test case + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + require.NoError(t, err) + + // Fail due to short key + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + require.EqualError(t, err, "bytes did not contain a proper private key banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + require.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalPublicKeyFromPEM(t *testing.T) { + pubKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA ED25519 PUBLIC KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA ED25519 PUBLIC KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PUBLIC KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA ED25519 PUBLIC KEY-----`) + + keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) + assert.Len(t, k, 32) + assert.Equal(t, Curve_CURVE25519, curve) + require.NoError(t, err) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + require.EqualError(t, err, "bytes did not contain a proper public key banner") + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, invalidPem) + require.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalX25519PublicKey(t *testing.T) { + pubKey := []byte(`# A good key +-----BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA X25519 PUBLIC KEY----- +`) + pubP256Key := []byte(`# A good key +-----BEGIN NEBULA P256 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +AAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PUBLIC KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA X25519 PUBLIC KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PUBLIC KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA X25519 PUBLIC KEY-----`) + + keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) + assert.Len(t, k, 32) + require.NoError(t, err) + assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + + // Success test case + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Len(t, k, 65) + require.NoError(t, err) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + + // Fail due to short key + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + require.EqualError(t, err, "bytes did not contain a proper public key banner") + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + require.EqualError(t, err, "input did not contain a valid PEM encoded block") +} diff --git a/cert/sign.go b/cert/sign.go new file mode 100644 index 0000000..12d4ee4 --- /dev/null +++ b/cert/sign.go @@ -0,0 +1,167 @@ +package cert + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "fmt" + "math/big" + "net/netip" + "time" +) + +// TBSCertificate represents a certificate intended to be signed. +// It is invalid to use this structure as a Certificate. +type TBSCertificate struct { + Version Version + Name string + Networks []netip.Prefix + UnsafeNetworks []netip.Prefix + Groups []string + IsCA bool + NotBefore time.Time + NotAfter time.Time + PublicKey []byte + Curve Curve + issuer string +} + +type beingSignedCertificate interface { + // fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation + // Implementations must validate the resulting certificate contains valid information + fromTBSCertificate(*TBSCertificate) error + + // marshalForSigning returns the bytes that should be signed + marshalForSigning() ([]byte, error) + + // setSignature sets the signature for the certificate that has just been signed. The signature must not be blank. + setSignature([]byte) error +} + +type SignerLambda func(certBytes []byte) ([]byte, error) + +// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those +// details do not violate constraints of the signing certificate. +// If the TBSCertificate is a CA then signer must be nil. +func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { + switch t.Curve { + case Curve_CURVE25519: + pk := ed25519.PrivateKey(key) + sp := func(certBytes []byte) ([]byte, error) { + sig := ed25519.Sign(pk, certBytes) + return sig, nil + } + return t.SignWith(signer, curve, sp) + case Curve_P256: + pk := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + pk.X, pk.Y = pk.Curve.ScalarBaseMult(key) + sp := func(certBytes []byte) ([]byte, error) { + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(certBytes) + return ecdsa.SignASN1(rand.Reader, pk, hashed[:]) + } + return t.SignWith(signer, curve, sp) + default: + return nil, fmt.Errorf("invalid curve: %s", t.Curve) + } +} + +// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature. +// You should only use SignWith if you do not have direct access to your private key. +func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) { + if curve != t.Curve { + return nil, fmt.Errorf("curve in cert and private key supplied don't match") + } + + if signer != nil { + if t.IsCA { + return nil, fmt.Errorf("can not sign a CA certificate with another") + } + + err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks) + if err != nil { + return nil, err + } + + issuer, err := signer.Fingerprint() + if err != nil { + return nil, fmt.Errorf("error computing issuer: %v", err) + } + t.issuer = issuer + } else { + if !t.IsCA { + return nil, fmt.Errorf("self signed certificates must have IsCA set to true") + } + } + + var c beingSignedCertificate + switch t.Version { + case Version1: + c = &certificateV1{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } + case Version2: + c = &certificateV2{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unknown cert version %d", t.Version) + } + + certBytes, err := c.marshalForSigning() + if err != nil { + return nil, err + } + + sig, err := sp(certBytes) + if err != nil { + return nil, err + } + + err = c.setSignature(sig) + if err != nil { + return nil, err + } + + sc, ok := c.(Certificate) + if !ok { + return nil, fmt.Errorf("invalid certificate") + } + + return sc, nil +} + +func comparePrefix(a, b netip.Prefix) int { + addr := a.Addr().Compare(b.Addr()) + if addr == 0 { + return a.Bits() - b.Bits() + } + return addr +} + +// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes +func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error { + if len(sortedPrefixes) < 2 { + return nil + } + for i := 1; i < len(sortedPrefixes); i++ { + if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 { + return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i]) + } + } + return nil +} diff --git a/cert/sign_test.go b/cert/sign_test.go new file mode 100644 index 0000000..e6f43cd --- /dev/null +++ b/cert/sign_test.go @@ -0,0 +1,91 @@ +package cert + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCertificateV1_Sign(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/24"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) + require.NoError(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + + b, err := c.Marshal() + require.NoError(t, err) + uc, err := unmarshalCertificateV1(b, nil) + require.NoError(t, err) + assert.NotNil(t, uc) +} + +func TestCertificateV1_SignP256(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + Curve: Curve_P256, + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := priv.D.FillBytes(make([]byte, 32)) + + c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) + require.NoError(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + + b, err := c.Marshal() + require.NoError(t, err) + uc, err := unmarshalCertificateV1(b, nil) + require.NoError(t, err) + assert.NotNil(t, uc) +} diff --git a/cert_test/cert.go b/cert_test/cert.go new file mode 100644 index 0000000..ebc6f52 --- /dev/null +++ b/cert_test/cert.go @@ -0,0 +1,138 @@ +package cert_test + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "io" + "net/netip" + "time" + + "github.com/slackhq/nebula/cert" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will create a new ca certificate +func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { + var err error + var pub, priv []byte + + switch curve { + case cert.Curve_CURVE25519: + pub, priv, err = ed25519.GenerateKey(rand.Reader) + case cert.Curve_P256: + privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + priv = privk.D.FillBytes(make([]byte, 32)) + default: + // There is no default to allow the underlying lib to respond with an error + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + t := &cert.TBSCertificate{ + Curve: curve, + Version: version, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, curve, priv) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. +// Expiry times are defaulted if you do not pass them in +func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + var pub, priv []byte + switch curve { + case cert.Curve_CURVE25519: + pub, priv = X25519Keypair() + case cert.Curve_P256: + pub, priv = P256Keypair() + default: + panic("unknown curve") + } + + nc := &cert.TBSCertificate{ + Version: v, + Curve: curve, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem +} + +func X25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} + +func P256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cidr/parse.go b/cidr/parse.go deleted file mode 100644 index 74367f6..0000000 --- a/cidr/parse.go +++ /dev/null @@ -1,10 +0,0 @@ -package cidr - -import "net" - -// Parse is a convenience function that returns only the IPNet -// This function ignores errors since it is primarily a test helper, the result could be nil -func Parse(s string) *net.IPNet { - _, c, _ := net.ParseCIDR(s) - return c -} diff --git a/cidr/tree4.go b/cidr/tree4.go deleted file mode 100644 index c5ebe54..0000000 --- a/cidr/tree4.go +++ /dev/null @@ -1,203 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -type Node[T any] struct { - left *Node[T] - right *Node[T] - parent *Node[T] - hasValue bool - value T -} - -type entry[T any] struct { - CIDR *net.IPNet - Value T -} - -type Tree4[T any] struct { - root *Node[T] - list []entry[T] -} - -const ( - startbit = iputil.VpnIp(0x80000000) -) - -func NewTree4[T any]() *Tree4[T] { - tree := new(Tree4[T]) - tree.root = &Node[T]{} - tree.list = []entry[T]{} - return tree -} - -func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { - bit := startbit - node := tree.root - next := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // We already have this range so update the value - if next != nil { - addCIDR := cidr.String() - for i, v := range tree.list { - if addCIDR == v.CIDR.String() { - tree.list = append(tree.list[:i], tree.list[i+1:]...) - break - } - } - - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) - node.value = val - node.hasValue = true - return - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) -} - -// Contains finds the first match, which may be the least specific -func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - return true, node.value - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - - } - - return false, value -} - -// MostSpecificContains finds the most specific match -func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -type eachFunc[T any] func(T) bool - -// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete -// The final return value will be true if the provided function returned true -func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - // If the each func returns true then we can exit the loop - if each(node.value) { - return true - } - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return false -} - -// GetCIDR returns the entry added by the most recent matching AddCIDR call -func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) { - bit := startbit - node := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for node != nil && bit&mask != 0 { - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit = bit >> 1 - } - - if bit&mask == 0 && node != nil { - value = node.value - ok = node.hasValue - } - - return ok, value -} - -// List will return all CIDRs and their current values. Do not modify the contents! -func (tree *Tree4[T]) List() []entry[T] { - return tree.list -} diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go deleted file mode 100644 index cd17be4..0000000 --- a/cidr/tree4_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package cidr - -import ( - "net" - "testing" - - "github.com/slackhq/nebula/iputil" - "github.com/stretchr/testify/assert" -) - -func TestCIDRTree_List(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/16"), "1") - tree.AddCIDR(Parse("1.0.0.0/8"), "2") - tree.AddCIDR(Parse("1.0.0.0/16"), "3") - tree.AddCIDR(Parse("1.0.0.0/16"), "4") - list := tree.List() - assert.Len(t, list, 2) - assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) - assert.Equal(t, "2", list[0].Value) - assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) - assert.Equal(t, "4", list[1].Value) -} - -func TestCIDRTree_Contains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4a", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.0/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestTree4_GetCIDR(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IPNet *net.IPNet - }{ - {true, "1", Parse("1.0.0.0/8")}, - {true, "2", Parse("2.1.0.0/16")}, - {true, "3", Parse("3.1.1.0/24")}, - {true, "4a", Parse("4.1.1.0/24")}, - {true, "4b", Parse("4.1.1.1/32")}, - {true, "4c", Parse("4.1.2.1/32")}, - {true, "5", Parse("254.0.0.0/4")}, - {false, "", Parse("2.0.0.0/8")}, - } - - for _, tt := range tests { - ok, r := tree.GetCIDR(tt.IPNet) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} - -func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.1.0.0/16"), "1") - tree.AddCIDR(Parse("1.2.1.1/32"), "1") - tree.AddCIDR(Parse("192.2.1.1/32"), "1") - tree.AddCIDR(Parse("172.2.1.1/32"), "1") - - ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1")) - b.Run("found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) - - ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255")) - b.Run("not found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) -} diff --git a/cidr/tree6.go b/cidr/tree6.go deleted file mode 100644 index 3f2cd2a..0000000 --- a/cidr/tree6.go +++ /dev/null @@ -1,189 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -const startbit6 = uint64(1 << 63) - -type Tree6[T any] struct { - root4 *Node[T] - root6 *Node[T] -} - -func NewTree6[T any]() *Tree6[T] { - tree := new(Tree6[T]) - tree.root4 = &Node[T]{} - tree.root6 = &Node[T]{} - return tree -} - -func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) { - var node, next *Node[T] - - cidrIP, ipv4 := isIPV4(cidr.IP) - if ipv4 { - node = tree.root4 - next = tree.root4 - - } else { - node = tree.root6 - next = tree.root6 - } - - for i := 0; i < len(cidrIP); i += 4 { - ip := iputil.Ip2VpnIp(cidrIP[i : i+4]) - mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4]) - bit := startbit - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true -} - -// Finds the most specific match -func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) { - var node *Node[T] - - wholeIP, ipv4 := isIPV4(ip) - if ipv4 { - node = tree.root4 - } else { - node = tree.root6 - } - - for i := 0; i < len(wholeIP); i += 4 { - ip := iputil.Ip2VpnIp(wholeIP[i : i+4]) - bit := startbit - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root4 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) { - ip := hi - node := tree.root6 - - for i := 0; i < 2; i++ { - bit := startbit6 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - ip = lo - } - - return ok, value -} - -func isIPV4(ip net.IP) (net.IP, bool) { - if len(ip) == net.IPv4len { - return ip, true - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16], true - } - - return ip, false -} - -func isZeros(p net.IP) bool { - for i := 0; i < len(p); i++ { - if p[i] != 0 { - return false - } - } - return true -} diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go deleted file mode 100644 index eb159ec..0000000 --- a/cidr/tree6_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package cidr - -import ( - "encoding/binary" - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.1/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP)) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree6[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - tree.AddCIDR(Parse("::/0"), "cool6") - ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("::")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) -} - -func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - } - - for _, tt := range tests { - ip := net.ParseIP(tt.IP) - hi := binary.BigEndian.Uint64(ip[:8]) - lo := binary.BigEndian.Uint64(ip[8:]) - - ok, r := tree.MostSpecificContainsIpV6(hi, lo) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 4e5d51d..f83c94f 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -8,13 +8,14 @@ import ( "fmt" "io" "math" - "net" + "net/netip" "os" "strings" "time" "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/ed25519" ) @@ -26,32 +27,43 @@ type caFlags struct { outCertPath *string outQRPath *string groups *string - ips *string - subnets *string + networks *string + unsafeNetworks *string argonMemory *uint argonIterations *uint argonParallelism *uint encryption *bool + version *uint - curve *string + curve *string + p11url *string + + // Deprecated options + ips *string + subnets *string } func newCaFlags() *caFlags { cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} cf.set.Usage = func() {} cf.name = cf.set.String("name", "", "Required: name of the certificate authority") + cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use") cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") - cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") - cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") + cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks") + cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks") cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") + cf.p11url = p11Flag(cf.set) + + cf.ips = cf.set.String("ips", "", "Deprecated, see -networks") + cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &cf } @@ -76,17 +88,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return err } + isP11 := len(*cf.p11url) > 0 + if err := mustFlagString("name", cf.name); err != nil { return err } - if err := mustFlagString("out-key", cf.outKeyPath); err != nil { - return err + if !isP11 { + if err = mustFlagString("out-key", cf.outKeyPath); err != nil { + return err + } } if err := mustFlagString("out-crt", cf.outCertPath); err != nil { return err } var kdfParams *cert.Argon2Parameters - if *cf.encryption { + if !isP11 && *cf.encryption { if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil { return err } @@ -106,44 +122,57 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } - var ips []*net.IPNet - if *cf.ips != "" { - for _, rs := range strings.Split(*cf.ips, ",") { + version := cert.Version(*cf.version) + if version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) + } + + var networks []netip.Prefix + if *cf.networks == "" && *cf.ips != "" { + // Pull up deprecated -ips flag if needed + *cf.networks = *cf.ips + } + + if *cf.networks != "" { + for _, rs := range strings.Split(*cf.networks, ",") { rs := strings.Trim(rs, " ") if rs != "" { - ip, ipNet, err := net.ParseCIDR(rs) + n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid ip definition: %s", err) + return newHelpErrorf("invalid -networks definition: %s", rs) } - if ip.To4() == nil { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs) } - - ipNet.IP = ip - ips = append(ips, ipNet) + networks = append(networks, n) } } } - var subnets []*net.IPNet - if *cf.subnets != "" { - for _, rs := range strings.Split(*cf.subnets, ",") { + var unsafeNetworks []netip.Prefix + if *cf.unsafeNetworks == "" && *cf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *cf.unsafeNetworks = *cf.subnets + } + + if *cf.unsafeNetworks != "" { + for _, rs := range strings.Split(*cf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { - _, s, err := net.ParseCIDR(rs) + n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid subnet definition: %s", err) + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } - if s.IP.To4() == nil { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs) } - subnets = append(subnets, s) + unsafeNetworks = append(unsafeNetworks, n) } } } var passphrase []byte - if *cf.encryption { + if !isP11 && *cf.encryption { for i := 0; i < 5; i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() @@ -166,74 +195,109 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error var curve cert.Curve var pub, rawPriv []byte - switch *cf.curve { - case "25519", "X25519", "Curve25519", "CURVE25519": - curve = cert.Curve_CURVE25519 - pub, rawPriv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return fmt.Errorf("error while generating ed25519 keys: %s", err) - } - case "P256": - var key *ecdsa.PrivateKey - curve = cert.Curve_P256 - key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return fmt.Errorf("error while generating ecdsa keys: %s", err) + var p11Client *pkclient.PKClient + + if isP11 { + switch *cf.curve { + case "P256": + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve) } - // ecdh.PrivateKey lets us get at the encoded bytes, even though - // we aren't using ECDH here. - eKey, err := key.ECDH() + p11Client, err = pkclient.FromUrl(*cf.p11url) if err != nil { - return fmt.Errorf("error while converting ecdsa key: %s", err) + return fmt.Errorf("error while creating PKCS#11 client: %w", err) + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key with PKCS#11: %w", err) + } + } else { + switch *cf.curve { + case "25519", "X25519", "Curve25519", "CURVE25519": + curve = cert.Curve_CURVE25519 + pub, rawPriv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ed25519 keys: %s", err) + } + case "P256": + var key *ecdsa.PrivateKey + curve = cert.Curve_P256 + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ecdsa keys: %s", err) + } + + // ecdh.PrivateKey lets us get at the encoded bytes, even though + // we aren't using ECDH here. + eKey, err := key.ECDH() + if err != nil { + return fmt.Errorf("error while converting ecdsa key: %s", err) + } + rawPriv = eKey.Bytes() + pub = eKey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", *cf.curve) } - rawPriv = eKey.Bytes() - pub = eKey.PublicKey().Bytes() } - nc := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: *cf.name, - Groups: groups, - Ips: ips, - Subnets: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*cf.duration), - PublicKey: pub, - IsCA: true, - Curve: curve, - }, + t := &cert.TBSCertificate{ + Version: version, + Name: *cf.name, + Groups: groups, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + NotBefore: time.Now(), + NotAfter: time.Now().Add(*cf.duration), + PublicKey: pub, + IsCA: true, + Curve: curve, } - if _, err := os.Stat(*cf.outKeyPath); err == nil { - return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) + if !isP11 { + if _, err := os.Stat(*cf.outKeyPath); err == nil { + return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) + } } if _, err := os.Stat(*cf.outCertPath); err == nil { return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) } - err = nc.Sign(curve, rawPriv) - if err != nil { - return fmt.Errorf("error while signing: %s", err) - } - + var c cert.Certificate var b []byte - if *cf.encryption { - b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) + + if isP11 { + c, err = t.SignWith(nil, curve, p11Client.SignASN1) if err != nil { - return fmt.Errorf("error while encrypting out-key: %s", err) + return fmt.Errorf("error while signing with PKCS#11: %w", err) } } else { - b = cert.MarshalSigningPrivateKey(curve, rawPriv) + c, err = t.Sign(nil, curve, rawPriv) + if err != nil { + return fmt.Errorf("error while signing: %s", err) + } + + if *cf.encryption { + b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) + if err != nil { + return fmt.Errorf("error while encrypting out-key: %s", err) + } + } else { + b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv) + } + + err = os.WriteFile(*cf.outKeyPath, b, 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } } - err = os.WriteFile(*cf.outKeyPath, b, 0600) - if err != nil { - return fmt.Errorf("error while writing out-key: %s", err) - } - - b, err = nc.MarshalToPEM() + b, err = c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 3a53405..b1cbde9 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -14,10 +14,9 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -//TODO: test file permissions - func Test_caSummary(t *testing.T) { assert.Equal(t, "ca : create a self signed certificate authority", caSummary()) } @@ -43,17 +42,24 @@ func Test_caHelp(t *testing.T) { " -groups string\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+ + " Deprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the certificate authority\n"+ + " -networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ " -out-key string\n"+ " \tOptional: path to write the private key to (default \"ca.key\")\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use (default 2)\n", ob.String(), ) } @@ -82,93 +88,94 @@ func Test_ca(t *testing.T) { // required args assertHelpError(t, ca( - []string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, + []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, ), "-name is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // ipv4 only ips - assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // ipv4 only subnets - assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() - args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} + require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) - os.Remove(keyF.Name()) + require.NoError(t, err) + require.NoError(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} + require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.Nil(t, err) - os.Remove(crtF.Name()) - os.Remove(keyF.Name()) + require.NoError(t, err) + require.NoError(t, os.Remove(crtF.Name())) + require.NoError(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, nopw)) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + require.NoError(t, ca(args, ob, eb, nopw)) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, c) + assert.Empty(t, b) + require.NoError(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) + assert.Empty(t, b) + require.NoError(t, err) - assert.Equal(t, "test", lCrt.Details.Name) - assert.Len(t, lCrt.Details.Ips, 0) - assert.True(t, lCrt.Details.IsCA) - assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups) - assert.Len(t, lCrt.Details.Subnets, 0) - assert.Len(t, lCrt.Details.PublicKey, 32) - assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore)) - assert.Equal(t, "", lCrt.Details.Issuer) - assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey)) + assert.Equal(t, "test", lCrt.Name()) + assert.Empty(t, lCrt.Networks()) + assert.True(t, lCrt.IsCA()) + assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) + assert.Empty(t, lCrt.UnsafeNetworks()) + assert.Len(t, lCrt.PublicKey(), 32) + assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) + assert.Empty(t, lCrt.Issuer()) + assert.True(t, lCrt.CheckSignature(lCrt.PublicKey())) // test encrypted key os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, testpw)) + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + require.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // read encrypted key file and verify default params rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) - assert.Nil(t, err) + require.NoError(t, err) // we won't know salt in advance, so just check start of string assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) @@ -178,8 +185,8 @@ func Test_ca(t *testing.T) { var curve cert.Curve curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Nil(t, err) - assert.Len(t, b, 0) + require.NoError(t, err) + assert.Empty(t, b) assert.Len(t, lKey, 64) // test when reading passsword results in an error @@ -187,45 +194,45 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Error(t, ca(args, ob, eb, errpw)) + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + require.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // test when user fails to enter a password os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, nopw)) + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + require.NoError(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // test that we won't overwrite existing key file os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) os.Remove(keyF.Name()) } diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index d94cbf1..496f84c 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -6,6 +6,8 @@ import ( "io" "os" + "github.com/slackhq/nebula/pkclient" + "github.com/slackhq/nebula/cert" ) @@ -13,8 +15,8 @@ type keygenFlags struct { set *flag.FlagSet outKeyPath *string outPubPath *string - - curve *string + curve *string + p11url *string } func newKeygenFlags() *keygenFlags { @@ -23,6 +25,7 @@ func newKeygenFlags() *keygenFlags { cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to") cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to") cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)") + cf.p11url = p11Flag(cf.set) return &cf } @@ -33,32 +36,58 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { return err } - if err := mustFlagString("out-key", cf.outKeyPath); err != nil { - return err + isP11 := len(*cf.p11url) > 0 + + if !isP11 { + if err = mustFlagString("out-key", cf.outKeyPath); err != nil { + return err + } } - if err := mustFlagString("out-pub", cf.outPubPath); err != nil { + if err = mustFlagString("out-pub", cf.outPubPath); err != nil { return err } var pub, rawPriv []byte var curve cert.Curve - switch *cf.curve { - case "25519", "X25519", "Curve25519", "CURVE25519": - pub, rawPriv = x25519Keypair() - curve = cert.Curve_CURVE25519 - case "P256": - pub, rawPriv = p256Keypair() - curve = cert.Curve_P256 - default: - return fmt.Errorf("invalid curve: %s", *cf.curve) + if isP11 { + switch *cf.curve { + case "P256": + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve) + } + } else { + switch *cf.curve { + case "25519", "X25519", "Curve25519", "CURVE25519": + pub, rawPriv = x25519Keypair() + curve = cert.Curve_CURVE25519 + case "P256": + pub, rawPriv = p256Keypair() + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve: %s", *cf.curve) + } } - err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) - if err != nil { - return fmt.Errorf("error while writing out-key: %s", err) + if isP11 { + p11Client, err := pkclient.FromUrl(*cf.p11url) + if err != nil { + return fmt.Errorf("error while creating PKCS#11 client: %w", err) + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key: %w", err) + } + } else { + err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } } - - err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600) + err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600) if err != nil { return fmt.Errorf("error while writing out-pub: %s", err) } @@ -72,7 +101,7 @@ func keygenSummary() string { func keygenHelp(out io.Writer) { cf := newKeygenFlags() - out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) + _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) cf.set.SetOutput(out) cf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 9a3b3f3..95d9893 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -7,10 +7,9 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -//TODO: test file permissions - func Test_keygenSummary(t *testing.T) { assert.Equal(t, "keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary()) } @@ -26,7 +25,8 @@ func Test_keygenHelp(t *testing.T) { " -out-key string\n"+ " \tRequired: path to write the private key to\n"+ " -out-pub string\n"+ - " \tRequired: path to write the public key to\n", + " \tRequired: path to write the public key to\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n"), ob.String(), ) } @@ -37,57 +37,59 @@ func Test_keygen(t *testing.T) { // required args assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(keyF.Name()) // failed pub write ob.Reset() eb.Reset() args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} - assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp pub file pubF, err := os.CreateTemp("", "test.pub") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(pubF.Name()) // test proper keygen ob.Reset() eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, keygen(args, ob, eb)) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + require.NoError(t, keygen(args, ob, eb)) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) + assert.Empty(t, b) + require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) - lPub, b, err := cert.UnmarshalX25519PublicKey(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) + assert.Empty(t, b) + require.NoError(t, err) assert.Len(t, lPub, 32) } diff --git a/cmd/nebula-cert/main.go b/cmd/nebula-cert/main.go index b803d30..c88626f 100644 --- a/cmd/nebula-cert/main.go +++ b/cmd/nebula-cert/main.go @@ -17,7 +17,7 @@ func (he *helpError) Error() string { return he.s } -func newHelpErrorf(s string, v ...interface{}) error { +func newHelpErrorf(s string, v ...any) error { return &helpError{s: fmt.Sprintf(s, v...)} } diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go index 3d0fa1b..2e92e7e 100644 --- a/cmd/nebula-cert/main_test.go +++ b/cmd/nebula-cert/main_test.go @@ -3,15 +3,15 @@ package main import ( "bytes" "errors" + "fmt" "io" "os" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -//TODO: all flag parsing continueOnError will print to stderr on its own currently - func Test_help(t *testing.T) { expected := "Usage of " + os.Args[0] + " :\n" + " Global flags:\n" + @@ -77,8 +77,16 @@ func assertHelpError(t *testing.T, err error, msg string) { case *helpError: // good default: - t.Fatal("err was not a helpError") + t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) } - assert.EqualError(t, err, msg) + require.EqualError(t, err, msg) +} + +func optionalPkcs11String(msg string) string { + if p11Supported() { + return msg + } else { + return "" + } } diff --git a/cmd/nebula-cert/p11_cgo.go b/cmd/nebula-cert/p11_cgo.go new file mode 100644 index 0000000..f1f1ec6 --- /dev/null +++ b/cmd/nebula-cert/p11_cgo.go @@ -0,0 +1,15 @@ +//go:build cgo && pkcs11 + +package main + +import ( + "flag" +) + +func p11Supported() bool { + return true +} + +func p11Flag(set *flag.FlagSet) *string { + return set.String("pkcs11", "", "Optional: PKCS#11 URI to an existing private key") +} diff --git a/cmd/nebula-cert/p11_stub.go b/cmd/nebula-cert/p11_stub.go new file mode 100644 index 0000000..5afeaea --- /dev/null +++ b/cmd/nebula-cert/p11_stub.go @@ -0,0 +1,16 @@ +//go:build !cgo || !pkcs11 + +package main + +import ( + "flag" +) + +func p11Supported() bool { + return false +} + +func p11Flag(set *flag.FlagSet) *string { + var ret = "" + return &ret +} diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 746d6a3..30e0965 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -45,28 +45,27 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("unable to read cert; %s", err) } - var c *cert.NebulaCertificate + var c cert.Certificate var qrBytes []byte part := 0 + var jsonCerts []cert.Certificate + for { - c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert) + c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { return fmt.Errorf("error while unmarshaling cert: %s", err) } if *pf.json { - b, _ := json.Marshal(c) - out.Write(b) - out.Write([]byte("\n")) - + jsonCerts = append(jsonCerts, c) } else { - out.Write([]byte(c.String())) - out.Write([]byte("\n")) + _, _ = out.Write([]byte(c.String())) + _, _ = out.Write([]byte("\n")) } if *pf.outQRPath != "" { - b, err := c.MarshalToPEM() + b, err := c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling cert to PEM: %s", err) } @@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { part++ } + if *pf.json { + b, _ := json.Marshal(jsonCerts) + _, _ = out.Write(b) + _, _ = out.Write([]byte("\n")) + } + if *pf.outQRPath != "" { b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5) if err != nil { diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 9fa8a54..221ab77 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -2,12 +2,17 @@ package main import ( "bytes" + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "net/netip" "os" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_printSummary(t *testing.T) { @@ -38,84 +43,203 @@ func Test_printCert(t *testing.T) { // no path err := printCert([]string{}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, err, "-path is required") // no cert at path ob.Reset() eb.Reset() err = printCert([]string{"-path", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) // invalid cert at path ob.Reset() eb.Reset() tf, err := os.CreateTemp("", "print-cert") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(tf.Name()) tf.WriteString("-----BEGIN NOPE-----") err = printCert([]string{"-path", tf.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") // test multiple certs ob.Reset() eb.Reset() tf.Truncate(0) tf.Seek(0, 0) - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Groups: []string{"hi"}, - PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - }, - Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - } + ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil) + c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"}) - p, _ := c.MarshalToPEM() + p, _ := c.MarshalPEM() tf.Write(p) tf.Write(p) tf.Write(p) err = printCert([]string{"-path", tf.Name()}, ob, eb) - assert.Nil(t, err) + fp, _ := c.Fingerprint() + pk := hex.EncodeToString(c.PublicKey()) + sig := hex.EncodeToString(c.Signature()) + require.NoError(t, err) assert.Equal( t, - "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n", + //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", + `{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +`, ob.String(), ) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // test json ob.Reset() eb.Reset() tf.Truncate(0) tf.Seek(0, 0) - c = cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Groups: []string{"hi"}, - PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - }, - Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - } - - p, _ = c.MarshalToPEM() tf.Write(p) tf.Write(p) tf.Write(p) err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb) - assert.Nil(t, err) + fp, _ = c.Fingerprint() + pk = hex.EncodeToString(c.PublicKey()) + sig = hex.EncodeToString(c.Signature()) + require.NoError(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n", + `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] +`, ob.String(), ) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) +} + +// NewTestCaCert will generate a CA cert +func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) { + var err error + if pubKey == nil || privKey == nil { + pubKey, privKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + } + + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: name, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pubKey, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey) + if err != nil { + panic(err) + } + + return c, privKey +} + +func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) { + if before.IsZero() { + before = ca.NotBefore() + } + + if after.IsZero() { + after = ca.NotAfter() + } + + if len(networks) == 0 { + networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} + } + + pub, rawPriv := x25519Keypair() + nc := &cert.TBSCertificate{ + Version: cert.Version1, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), signerKey) + if err != nil { + panic(err) + } + + return c, rawPriv } diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 35d6446..ebcb592 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -3,50 +3,63 @@ package main import ( "crypto/ecdh" "crypto/rand" + "errors" "flag" "fmt" "io" - "net" + "net/netip" "os" "strings" "time" "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" ) type signFlags struct { - set *flag.FlagSet - caKeyPath *string - caCertPath *string - name *string - ip *string - duration *time.Duration - inPubPath *string - outKeyPath *string - outCertPath *string - outQRPath *string - groups *string - subnets *string + set *flag.FlagSet + version *uint + caKeyPath *string + caCertPath *string + name *string + networks *string + unsafeNetworks *string + duration *time.Duration + inPubPath *string + outKeyPath *string + outCertPath *string + outQRPath *string + groups *string + + p11url *string + + // Deprecated options + ip *string + subnets *string } func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} + sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") - sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert") + sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert") + sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for") sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") - sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for") - return &sf + sf.p11url = p11Flag(sf.set) + sf.ip = sf.set.String("ip", "", "Deprecated, see -networks") + sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks") + return &sf } func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { @@ -56,8 +69,12 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return err } - if err := mustFlagString("ca-key", sf.caKeyPath); err != nil { - return err + isP11 := len(*sf.p11url) > 0 + + if !isP11 { + if err := mustFlagString("ca-key", sf.caKeyPath); err != nil { + return err + } } if err := mustFlagString("ca-crt", sf.caCertPath); err != nil { return err @@ -65,50 +82,67 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if err := mustFlagString("name", sf.name); err != nil { return err } - if err := mustFlagString("ip", sf.ip); err != nil { - return err - } - if *sf.inPubPath != "" && *sf.outKeyPath != "" { + if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } - rawCAKey, err := os.ReadFile(*sf.caKeyPath) - if err != nil { - return fmt.Errorf("error while reading ca-key: %s", err) + var v4Networks []netip.Prefix + var v6Networks []netip.Prefix + if *sf.networks == "" && *sf.ip != "" { + // Pull up deprecated -ip flag if needed + *sf.networks = *sf.ip + } + + if len(*sf.networks) == 0 { + return newHelpErrorf("-networks is required") + } + + version := cert.Version(*sf.version) + if version != 0 && version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) } var curve cert.Curve var caKey []byte - // naively attempt to decode the private key as though it is not encrypted - caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey) - if err == cert.ErrPrivateKeyEncrypted { - // ask for a passphrase until we get one - var passphrase []byte - for i := 0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) - passphrase, err = pr.ReadPassword() + if !isP11 { + var rawCAKey []byte + rawCAKey, err := os.ReadFile(*sf.caKeyPath) - if err == ErrNoTerminal { - return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") - } else if err != nil { - return fmt.Errorf("error reading password: %s", err) - } - - if len(passphrase) > 0 { - break - } - } - if len(passphrase) == 0 { - return fmt.Errorf("cannot open encrypted ca-key without passphrase") - } - - curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) if err != nil { - return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + return fmt.Errorf("error while reading ca-key: %s", err) + } + + // naively attempt to decode the private key as though it is not encrypted + caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) + if errors.Is(err, cert.ErrPrivateKeyEncrypted) { + // ask for a passphrase until we get one + var passphrase []byte + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if errors.Is(err, ErrNoTerminal) { + return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading password: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + if len(passphrase) == 0 { + return fmt.Errorf("cannot open encrypted ca-key without passphrase") + } + + curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) + if err != nil { + return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + } + } else if err != nil { + return fmt.Errorf("error while parsing ca-key: %s", err) } - } else if err != nil { - return fmt.Errorf("error while parsing ca-key: %s", err) } rawCACert, err := os.ReadFile(*sf.caCertPath) @@ -116,18 +150,15 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while reading ca-crt: %s", err) } - caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert) + caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert) if err != nil { return fmt.Errorf("error while parsing ca-crt: %s", err) } - if err := caCert.VerifyPrivateKey(curve, caKey); err != nil { - return fmt.Errorf("refusing to sign, root certificate does not match private key") - } - - issuer, err := caCert.Sha256Sum() - if err != nil { - return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err) + if !isP11 { + if err := caCert.VerifyPrivateKey(curve, caKey); err != nil { + return fmt.Errorf("refusing to sign, root certificate does not match private key") + } } if caCert.Expired(time.Now()) { @@ -136,19 +167,53 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) // if no duration is given, expire one second before the root expires if *sf.duration <= 0 { - *sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1 + *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - ip, ipNet, err := net.ParseCIDR(*sf.ip) - if err != nil { - return newHelpErrorf("invalid ip definition: %s", err) - } - if ip.To4() == nil { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip) - } - ipNet.IP = ip + if *sf.networks != "" { + for _, rs := range strings.Split(*sf.networks, ",") { + rs := strings.Trim(rs, " ") + if rs != "" { + n, err := netip.ParsePrefix(rs) + if err != nil { + return newHelpErrorf("invalid -networks definition: %s", rs) + } - groups := []string{} + if n.Addr().Is4() { + v4Networks = append(v4Networks, n) + } else { + v6Networks = append(v6Networks, n) + } + } + } + } + + var v4UnsafeNetworks []netip.Prefix + var v6UnsafeNetworks []netip.Prefix + if *sf.unsafeNetworks == "" && *sf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *sf.unsafeNetworks = *sf.subnets + } + + if *sf.unsafeNetworks != "" { + for _, rs := range strings.Split(*sf.unsafeNetworks, ",") { + rs := strings.Trim(rs, " ") + if rs != "" { + n, err := netip.ParsePrefix(rs) + if err != nil { + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) + } + + if n.Addr().Is4() { + v4UnsafeNetworks = append(v4UnsafeNetworks, n) + } else { + v6UnsafeNetworks = append(v6UnsafeNetworks, n) + } + } + } + } + + var groups []string if *sf.groups != "" { for _, rg := range strings.Split(*sf.groups, ",") { g := strings.TrimSpace(rg) @@ -158,60 +223,43 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - subnets := []*net.IPNet{} - if *sf.subnets != "" { - for _, rs := range strings.Split(*sf.subnets, ",") { - rs := strings.Trim(rs, " ") - if rs != "" { - _, s, err := net.ParseCIDR(rs) - if err != nil { - return newHelpErrorf("invalid subnet definition: %s", err) - } - if s.IP.To4() == nil { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) - } - subnets = append(subnets, s) - } + var pub, rawPriv []byte + var p11Client *pkclient.PKClient + + if isP11 { + curve = cert.Curve_P256 + p11Client, err = pkclient.FromUrl(*sf.p11url) + if err != nil { + return fmt.Errorf("error while creating PKCS#11 client: %w", err) } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) } - var pub, rawPriv []byte if *sf.inPubPath != "" { + var pubCurve cert.Curve rawPub, err := os.ReadFile(*sf.inPubPath) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } - var pubCurve cert.Curve - pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub) + + pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub) if err != nil { return fmt.Errorf("error while parsing in-pub: %s", err) } if pubCurve != curve { return fmt.Errorf("curve of in-pub does not match ca") } + } else if isP11 { + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key with PKCS#11: %w", err) + } } else { pub, rawPriv = newKeypair(curve) } - nc := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: *sf.name, - Ips: []*net.IPNet{ipNet}, - Groups: groups, - Subnets: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*sf.duration), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - Curve: curve, - }, - } - - if err := nc.CheckRootConstrains(caCert); err != nil { - return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err) - } - if *sf.outKeyPath == "" { *sf.outKeyPath = *sf.name + ".key" } @@ -224,25 +272,105 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } - err = nc.Sign(curve, caKey) - if err != nil { - return fmt.Errorf("error while signing: %s", err) + var crts []cert.Certificate + + notBefore := time.Now() + notAfter := notBefore.Add(*sf.duration) + + if version == 0 || version == cert.Version1 { + // Make sure we at least have an ip + if len(v4Networks) != 1 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") + } + + if version == cert.Version1 { + // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses + if len(v6Networks) > 0 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4") + } + + if len(v6UnsafeNetworks) > 0 { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4") + } + } + + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *sf.name, + Networks: []netip.Prefix{v4Networks[0]}, + Groups: groups, + UnsafeNetworks: v4UnsafeNetworks, + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + + crts = append(crts, nc) } - if *sf.inPubPath == "" { + if version == 0 || version == cert.Version2 { + t := &cert.TBSCertificate{ + Version: cert.Version2, + Name: *sf.name, + Networks: append(v4Networks, v6Networks...), + Groups: groups, + UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...), + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + + crts = append(crts, nc) + } + + if !isP11 && *sf.inPubPath == "" { if _, err := os.Stat(*sf.outKeyPath); err == nil { return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) } - err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) + err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } - b, err := nc.MarshalToPEM() - if err != nil { - return fmt.Errorf("error while marshalling certificate: %s", err) + var b []byte + for _, c := range crts { + sb, err := c.MarshalPEM() + if err != nil { + return fmt.Errorf("error while marshalling certificate: %s", err) + } + b = append(b, sb...) } err = os.WriteFile(*sf.outCertPath, b, 0600) diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index adf83a2..b2bba76 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -13,11 +13,10 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) -//TODO: test file permissions - func Test_signSummary(t *testing.T) { assert.Equal(t, "sign : create and sign a certificate", signSummary()) } @@ -39,17 +38,24 @@ func Test_signHelp(t *testing.T) { " -in-pub string\n"+ " \tOptional (if out-key not set): path to read a previously generated public key\n"+ " -ip string\n"+ - " \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+ + " \tDeprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the cert, usually a hostname\n"+ + " -networks string\n"+ + " \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to\n"+ " -out-key string\n"+ " \tOptional (if in-pub not set): path to write the private key to\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n", ob.String(), ) } @@ -76,20 +82,20 @@ func Test_signCert(t *testing.T) { // required args assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, - ), "-ip is required") + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-networks is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -97,18 +103,18 @@ func Test_signCert(t *testing.T) { // failed to read key ob.Reset() eb.Reset() - args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) + args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key ob.Reset() eb.Reset() caKeyF, err := os.CreateTemp("", "sign-cert.key") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(caKeyF.Name()) - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -116,11 +122,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) - caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv)) + caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) // failed to read cert - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -128,30 +134,22 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caCrtF, err := os.CreateTemp("", "sign-cert.crt") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(caCrtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // write a proper ca cert for later - ca := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute * 200), - PublicKey: caPub, - IsCA: true, - }, - } - b, _ := ca.MarshalToPEM() + ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) + b, _ := ca.MarshalPEM() caCrtF.Write(b) // failed to read pub - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -159,11 +157,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPubF, err := os.CreateTemp("", "in.pub") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(inPubF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -171,116 +169,124 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPub, _ := x25519Keypair() - inPubF.Write(cert.MarshalX25519PublicKey(inPub)) + inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)) // bad ip cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // bad subnet cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(caKeyF2.Name()) - caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2)) + caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) + require.NoError(t, err) os.Remove(keyF.Name()) // failed cert write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Remove(keyF.Name()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.Nil(t, err) + require.NoError(t, err) os.Remove(crtF.Name()) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) + assert.Empty(t, b) + require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) + assert.Empty(t, b) + require.NoError(t, err) - assert.Equal(t, "test", lCrt.Details.Name) - assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String()) - assert.Len(t, lCrt.Details.Ips, 1) - assert.False(t, lCrt.Details.IsCA) - assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups) - assert.Len(t, lCrt.Details.Subnets, 3) - assert.Len(t, lCrt.Details.PublicKey, 32) - assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore)) + assert.Equal(t, "test", lCrt.Name()) + assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) + assert.Len(t, lCrt.Networks(), 1) + assert.False(t, lCrt.IsCA()) + assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) + assert.Len(t, lCrt.UnsafeNetworks(), 3) + assert.Len(t, lCrt.PublicKey(), 32) + assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) sns := []string{} - for _, sn := range lCrt.Details.Subnets { + for _, sn := range lCrt.UnsafeNetworks() { sns = append(sns, sn.String()) } assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns) - issuer, _ := ca.Sha256Sum() - assert.Equal(t, issuer, lCrt.Details.Issuer) + issuer, _ := ca.Fingerprint() + assert.Equal(t, issuer, lCrt.Issuer()) assert.True(t, lCrt.CheckSignature(caPub)) @@ -289,53 +295,55 @@ func Test_signCert(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} + require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert file and check pub key matches in-pub rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) - assert.Equal(t, lCrt.Details.PublicKey, inPub) + lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) + assert.Empty(t, b) + require.NoError(t, err) + assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -348,11 +356,11 @@ func Test_signCert(t *testing.T) { eb.Reset() caKeyF, err = os.CreateTemp("", "sign-cert.key") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(caKeyF.Name()) caCrtF, err = os.CreateTemp("", "sign-cert.crt") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(caCrtF.Name()) // generate the encrypted key @@ -361,21 +369,13 @@ func Test_signCert(t *testing.T) { b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams) caKeyF.Write(b) - ca = cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute * 200), - PublicKey: caPub, - IsCA: true, - }, - } - b, _ = ca.MarshalToPEM() + ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) + b, _ = ca.MarshalPEM() caCrtF.Write(b) // test with the proper password - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, testpw)) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.NoError(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -384,8 +384,8 @@ func Test_signCert(t *testing.T) { eb.Reset() testpw.password = []byte("invalid password") - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, testpw)) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -393,8 +393,8 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, nopw)) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -403,8 +403,8 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, errpw)) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + require.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) } diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index c955913..bea4d1d 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "io" @@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCACert, err := os.ReadFile(*vf.caPath) if err != nil { - return fmt.Errorf("error while reading ca: %s", err) + return fmt.Errorf("error while reading ca: %w", err) } caPool := cert.NewCAPool() for { - rawCACert, err = caPool.AddCACertificate(rawCACert) + rawCACert, err = caPool.AddCAFromPEM(rawCACert) if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %s", err) + return fmt.Errorf("error while adding ca cert to pool: %w", err) } if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { @@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCert, err := os.ReadFile(*vf.certPath) if err != nil { - return fmt.Errorf("unable to read crt; %s", err) + return fmt.Errorf("unable to read crt: %w", err) + } + var errs []error + for { + if len(rawCert) == 0 { + break + } + c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert) + if err != nil { + return fmt.Errorf("error while parsing crt: %w", err) + } + rawCert = extra + _, err = caPool.VerifyCertificate(time.Now(), c) + if err != nil { + switch { + case errors.Is(err, cert.ErrCaNotFound): + errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err)) + default: + errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err)) + } + } } - c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) - if err != nil { - return fmt.Errorf("error while parsing crt: %s", err) - } - - good, err := c.Verify(time.Now(), caPool) - if !good { - return err - } - - return nil + return errors.Join(errs...) } func verifySummary() string { @@ -80,7 +91,7 @@ func verifySummary() string { func verifyHelp(out io.Writer) { vf := newVerifyFlags() - out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) + _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) vf.set.SetOutput(out) vf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f0f4c78..f555e5f 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -9,6 +9,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) @@ -37,105 +38,87 @@ func Test_verify(t *testing.T) { // required args assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // no ca at path ob.Reset() eb.Reset() err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) // invalid ca at path ob.Reset() eb.Reset() caFile, err := os.CreateTemp("", "verify-ca") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(caFile.Name()) caFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) - ca := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test-ca", - NotBefore: time.Now().Add(time.Hour * -1), - NotAfter: time.Now().Add(time.Hour * 2), - PublicKey: caPub, - IsCA: true, - }, - } - ca.Sign(cert.Curve_CURVE25519, caPriv) - b, _ := ca.MarshalToPEM() + ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil) + b, _ := ca.MarshalPEM() caFile.Truncate(0) caFile.Seek(0, 0) caFile.Write(b) // no crt at path err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() eb.Reset() certFile, err := os.CreateTemp("", "verify-cert") - assert.Nil(t, err) + require.NoError(t, err) defer os.Remove(certFile.Name()) certFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path - _, badPriv, _ := ed25519.GenerateKey(rand.Reader) - certPub, _ := x25519Keypair() - signer, _ := ca.Sha256Sum() - crt := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test-cert", - NotBefore: time.Now().Add(time.Hour * -1), - NotAfter: time.Now().Add(time.Hour), - PublicKey: certPub, - IsCA: false, - Issuer: signer, - }, + crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) + // Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature + pub := crt.PublicKey() + for i, _ := range pub { + pub[i] = 0 } - - crt.Sign(cert.Curve_CURVE25519, badPriv) - b, _ = crt.MarshalToPEM() + b, _ = crt.MarshalPEM() certFile.Truncate(0) certFile.Seek(0, 0) certFile.Write(b) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "certificate signature did not match") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + require.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path - crt.Sign(cert.Curve_CURVE25519, caPriv) - b, _ = crt.MarshalToPEM() + crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) + b, _ = crt.MarshalPEM() certFile.Truncate(0) certFile.Seek(0, 0) certFile.Write(b) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) - assert.Nil(t, err) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + require.NoError(t, err) } diff --git a/config/config.go b/config/config.go index 1aea832..b1531e9 100644 --- a/config/config.go +++ b/config/config.go @@ -17,14 +17,14 @@ import ( "dario.cat/mergo" "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) type C struct { path string files []string - Settings map[interface{}]interface{} - oldSettings map[interface{}]interface{} + Settings map[string]any + oldSettings map[string]any callbacks []func(*C) l *logrus.Logger reloadLock sync.Mutex @@ -32,7 +32,7 @@ type C struct { func NewC(l *logrus.Logger) *C { return &C{ - Settings: make(map[interface{}]interface{}), + Settings: make(map[string]any), l: l, } } @@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool { } var ( - nv interface{} - ov interface{} + nv any + ov any ) if k == "" { @@ -147,7 +147,7 @@ func (c *C) ReloadConfig() { c.reloadLock.Lock() defer c.reloadLock.Unlock() - c.oldSettings = make(map[interface{}]interface{}) + c.oldSettings = make(map[string]any) for k, v := range c.Settings { c.oldSettings[k] = v } @@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error { c.reloadLock.Lock() defer c.reloadLock.Unlock() - c.oldSettings = make(map[interface{}]interface{}) + c.oldSettings = make(map[string]any) for k, v := range c.Settings { c.oldSettings[k] = v } @@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string { return d } - rv, ok := r.([]interface{}) + rv, ok := r.([]any) if !ok { return d } @@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string { } // GetMap will get the map for k or return the default d if not found or invalid -func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} { +func (c *C) GetMap(k string, d map[string]any) map[string]any { r := c.Get(k) if r == nil { return d } - v, ok := r.(map[interface{}]interface{}) + v, ok := r.(map[string]any) if !ok { return d } @@ -266,6 +266,22 @@ func (c *C) GetBool(k string, d bool) bool { return v } +func AsBool(v any) (value bool, ok bool) { + switch x := v.(type) { + case bool: + return x, true + case string: + switch x { + case "y", "yes": + return true, true + case "n", "no": + return false, true + } + } + + return false, false +} + // GetDuration will get the duration for k or return the default d if not found or invalid func (c *C) GetDuration(k string, d time.Duration) time.Duration { r := c.GetString(k, "") @@ -276,7 +292,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration { return v } -func (c *C) Get(k string) interface{} { +func (c *C) Get(k string) any { return c.get(k, c.Settings) } @@ -284,10 +300,10 @@ func (c *C) IsSet(k string) bool { return c.get(k, c.Settings) != nil } -func (c *C) get(k string, v interface{}) interface{} { +func (c *C) get(k string, v any) any { parts := strings.Split(k, ".") for _, p := range parts { - m, ok := v.(map[interface{}]interface{}) + m, ok := v.(map[string]any) if !ok { return nil } @@ -346,7 +362,7 @@ func (c *C) addFile(path string, direct bool) error { } func (c *C) parseRaw(b []byte) error { - var m map[interface{}]interface{} + var m map[string]any err := yaml.Unmarshal(b, &m) if err != nil { @@ -358,7 +374,7 @@ func (c *C) parseRaw(b []byte) error { } func (c *C) parse() error { - var m map[interface{}]interface{} + var m map[string]any for _, path := range c.files { b, err := os.ReadFile(path) @@ -366,7 +382,7 @@ func (c *C) parse() error { return err } - var nm map[interface{}]interface{} + var nm map[string]any err = yaml.Unmarshal(b, &nm) if err != nil { return err diff --git a/config/config_test.go b/config/config_test.go index fa94393..ec5a4b0 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -10,7 +10,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) func TestConfig_Load(t *testing.T) { @@ -19,40 +19,37 @@ func TestConfig_Load(t *testing.T) { // invalid yaml c := NewC(l) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) - assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") + require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}") // simple multi config merge c = NewC(l) os.RemoveAll(dir) os.Mkdir(dir, 0755) - assert.Nil(t, err) + require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) - assert.Nil(t, c.Load(dir)) - expected := map[interface{}]interface{}{ - "outer": map[interface{}]interface{}{ + require.NoError(t, c.Load(dir)) + expected := map[string]any{ + "outer": map[string]any{ "inner": "override", }, "new": "hi", } assert.Equal(t, expected, c.Settings) - - //TODO: test symlinked file - //TODO: test symlinked directory } func TestConfig_Get(t *testing.T) { l := test.NewLogger() // test simple type c := NewC(l) - c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} + c.Settings["firewall"] = map[string]any{"outbound": "hi"} assert.Equal(t, "hi", c.Get("firewall.outbound")) // test complex type - inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}} - c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner} + inner := []map[string]any{{"port": "1", "code": "2"}} + c.Settings["firewall"] = map[string]any{"outbound": inner} assert.EqualValues(t, inner, c.Get("firewall.outbound")) // test missing @@ -62,7 +59,7 @@ func TestConfig_Get(t *testing.T) { func TestConfig_GetStringSlice(t *testing.T) { l := test.NewLogger() c := NewC(l) - c.Settings["slice"] = []interface{}{"one", "two"} + c.Settings["slice"] = []any{"one", "two"} assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) } @@ -70,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) { l := test.NewLogger() c := NewC(l) c.Settings["bool"] = true - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "true" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = false - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "false" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "Y" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "yEs" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "N" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "nO" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) } func TestConfig_HasChanged(t *testing.T) { @@ -104,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) { // Test key change c = NewC(l) c.Settings["test"] = "hi" - c.oldSettings = map[interface{}]interface{}{"test": "no"} + c.oldSettings = map[string]any{"test": "no"} assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("")) // No key change c = NewC(l) c.Settings["test"] = "hi" - c.oldSettings = map[interface{}]interface{}{"test": "hi"} + c.oldSettings = map[string]any{"test": "hi"} assert.False(t, c.HasChanged("test")) assert.False(t, c.HasChanged("")) } @@ -120,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) dir, err := os.MkdirTemp("", "config-test") - assert.Nil(t, err) + require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) - assert.Nil(t, c.Load(dir)) + require.NoError(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer")) @@ -187,11 +184,11 @@ firewall: `), } - var m map[any]any + var m map[string]any // merge the same way config.parse() merges for _, b := range configs { - var nm map[any]any + var nm map[string]any err := yaml.Unmarshal(b, &nm) require.NoError(t, err) @@ -208,15 +205,15 @@ firewall: t.Logf("Merged Config as YAML:\n%s", mYaml) // If a bug is present, some items might be replaced instead of merged like we expect - expected := map[any]any{ - "firewall": map[any]any{ + expected := map[string]any{ + "firewall": map[string]any{ "inbound": []any{ - map[any]any{"host": "any", "port": "any", "proto": "icmp"}, - map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"}, - map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}}, + map[string]any{"host": "any", "port": "any", "proto": "icmp"}, + map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"}, + map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}}, "outbound": []any{ - map[any]any{"host": "any", "port": "any", "proto": "any"}}}, - "listen": map[any]any{ + map[string]any{"host": "any", "port": "any", "proto": "any"}}}, + "listen": map[string]any{ "host": "0.0.0.0", "port": 4242, }, diff --git a/connection_manager.go b/connection_manager.go index f9e1b71..19c6223 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -3,14 +3,14 @@ package nebula import ( "bytes" "context" + "encoding/binary" + "net/netip" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) type trafficDecision int @@ -182,7 +182,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case deleteTunnel: if n.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) + n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: @@ -220,11 +220,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) relayFor := oldhostinfo.relayState.CopyAllRelayFor() for _, r := range relayFor { - existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) + existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr) var index uint32 - var relayFrom iputil.VpnIp - var relayTo iputil.VpnIp + var relayFrom netip.Addr + var relayTo netip.Addr switch { case ok && existing.State == Established: // This relay already exists in newhostinfo, then do nothing. @@ -234,11 +234,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp - relayTo = existing.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = existing.PeerAddr case ForwardingType: - relayFrom = existing.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = existing.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } @@ -252,18 +252,18 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) n.relayUsedLock.RUnlock() // The relay doesn't exist at all; create some relay state and send the request. var err error - index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) + index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { n.l.WithError(err).Error("failed to migrate relay to new hostinfo") continue } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp - relayTo = r.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = r.PeerAddr case ForwardingType: - relayFrom = r.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = r.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } @@ -273,20 +273,43 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(relayFrom), - RelayToIp: uint32(relayTo), } + + switch newhostinfo.GetCert().Certificate.Version() { + case cert.Version1: + if !relayFrom.Is4() { + n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !relayTo.Is4() { + n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := relayFrom.As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = relayTo.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + req.RelayFromAddr = netAddrToProtoAddr(relayFrom) + req.RelayToAddr = netAddrToProtoAddr(relayTo) + default: + newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay") + continue + } + msg, err := req.Marshal() if err != nil { n.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + "relayFrom": req.RelayFromAddr, + "relayTo": req.RelayToAddr, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": newhostinfo.vpnIp}). + "vpnAddrs": newhostinfo.vpnAddrs}). Info("send CreateRelayRequest") } } @@ -308,7 +331,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time return closeTunnel, hostinfo, nil } - primary := n.hostMap.Hosts[hostinfo.vpnIp] + primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false @@ -402,21 +425,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp < n.intf.myVpnIp { - // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. - // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. - // The remotes vpn ip is lower than mine. I will not flip. + // Only one side should swap because if both swap then we may never resolve to a single tunnel. + // vpn addr is static across all tunnels for this host pair so lets + // use that to determine if we should consider swapping. + if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { + // Their primary vpn addr is less than mine. Do not swap. return false } - certState := n.intf.pki.GetCertState() - return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) + crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things + // settle down. + return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { n.hostMap.Lock() // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. - if n.hostMap.Hosts[current.vpnIp] == primary { + if n.hostMap.Hosts[current.vpnAddrs[0]] == primary { n.hostMap.unlockedMakePrimary(current) } n.hostMap.Unlock() @@ -431,8 +457,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) - if valid { + caPool := n.intf.pki.GetCAPool() + err := caPool.VerifyCachedCertificate(now, remoteCert) + if err == nil { return false } @@ -441,9 +468,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - fingerprint, _ := remoteCert.Sha256Sum() hostinfo.logger(n.l).WithError(err). - WithField("fingerprint", fingerprint). + WithField("fingerprint", remoteCert.Fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") return true @@ -456,26 +482,29 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } if n.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) { + hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, addr) }) - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - certState := n.intf.pki.GetCertState() - if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { + cs := n.intf.pki.getCertState() + curCrt := hostinfo.ConnectionState.myCert + myCrt := cs.getCertificate(curCrt.Version()) + if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + // The current tunnel is using the latest certificate and version, no need to rehandshake. return } - n.l.WithField("vpnIp", hostinfo.vpnIp). + n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") - n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) + n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index f50bcf8..2c9baa1 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -4,29 +4,27 @@ import ( "context" "crypto/ed25519" "crypto/rand" - "net" + "net/netip" "testing" "time" "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var vpnIp iputil.VpnIp - func newTestLighthouse() *LightHouse { lh := &LightHouse{ l: test.NewLogger(), - addrMap: map[iputil.VpnIp]*RemoteList{}, - queryChan: make(chan iputil.VpnIp, 10), + addrMap: map[netip.Addr]*RemoteList{}, + queryChan: make(chan netip.Addr, 10), } - lighthouses := map[iputil.VpnIp]struct{}{} - staticList := map[iputil.VpnIp]struct{}{} + lighthouses := map[netip.Addr]struct{}{} + staticList := map[netip.Addr]struct{}{} lh.lighthouses.Store(&lighthouses) lh.staticList.Store(&staticList) @@ -37,20 +35,19 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -77,12 +74,12 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -91,7 +88,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.out, hostinfo.localIndexId) @@ -108,31 +105,31 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -159,12 +156,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -172,8 +169,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0]) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded @@ -189,7 +186,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // We saw traffic, should no longer be pending deletion nc.In(hostinfo.localIndexId) @@ -198,7 +195,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } // Check if we can disconnect the peer. @@ -207,54 +204,48 @@ func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { now := time.Now() l := test.NewLogger() - ipNet := net.IPNet{ - IP: net.IPv4(172, 1, 1, 2), - Mask: net.IPMask{255, 255, 255, 0}, - } - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} - hostMap := newHostMap(l, vpncidr) + + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) - caCert := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: now, - NotAfter: now.Add(1 * time.Hour), - IsCA: true, - PublicKey: pubCA, - }, + tbs := &cert.TBSCertificate{ + Version: 1, + Name: "ca", + IsCA: true, + NotBefore: now, + NotAfter: now.Add(1 * time.Hour), + PublicKey: pubCA, } - assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA)) - ncp := &cert.NebulaCAPool{ - CAs: cert.NewCAPool().CAs, - } - ncp.CAs["ca"] = &caCert + caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) + require.NoError(t, err) + ncp := cert.NewCAPool() + require.NoError(t, ncp.AddCA(caCert)) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) - peerCert := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host", - Ips: []*net.IPNet{&ipNet}, - Subnets: []*net.IPNet{}, - NotBefore: now, - NotAfter: now.Add(60 * time.Second), - PublicKey: pubCrt, - IsCA: false, - Issuer: "ca", - }, + tbs = &cert.TBSCertificate{ + Version: 1, + Name: "host", + Networks: []netip.Prefix{vpncidr}, + NotBefore: now, + NotAfter: now.Add(60 * time.Second), + PublicKey: pubCrt, } - assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA)) + peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) + require.NoError(t, err) + + cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, - RawCertificateNoKey: []byte{}, + privateKey: []byte{}, + v1Cert: &dummyCert{}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -280,10 +271,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.connectionManager = nc hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, ConnectionState: &ConnectionState{ - myCert: &cert.NebulaCertificate{}, - peerCert: &peerCert, + myCert: &dummyCert{}, + peerCert: cachedPeerCert, H: &noise.HandshakeState{}, }, } @@ -303,3 +294,114 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { invalid = nc.isInvalidCertificate(nextTick, hostinfo) assert.True(t, invalid) } + +type dummyCert struct { + version cert.Version + curve cert.Curve + groups []string + isCa bool + issuer string + name string + networks []netip.Prefix + notAfter time.Time + notBefore time.Time + publicKey []byte + signature []byte + unsafeNetworks []netip.Prefix +} + +func (d *dummyCert) Version() cert.Version { + return d.version +} + +func (d *dummyCert) Curve() cert.Curve { + return d.curve +} + +func (d *dummyCert) Groups() []string { + return d.groups +} + +func (d *dummyCert) IsCA() bool { + return d.isCa +} + +func (d *dummyCert) Issuer() string { + return d.issuer +} + +func (d *dummyCert) Name() string { + return d.name +} + +func (d *dummyCert) Networks() []netip.Prefix { + return d.networks +} + +func (d *dummyCert) NotAfter() time.Time { + return d.notAfter +} + +func (d *dummyCert) NotBefore() time.Time { + return d.notBefore +} + +func (d *dummyCert) PublicKey() []byte { + return d.publicKey +} + +func (d *dummyCert) Signature() []byte { + return d.signature +} + +func (d *dummyCert) UnsafeNetworks() []netip.Prefix { + return d.unsafeNetworks +} + +func (d *dummyCert) MarshalForHandshakes() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Sign(curve cert.Curve, key []byte) error { + return nil +} + +func (d *dummyCert) CheckSignature(key []byte) bool { + return true +} + +func (d *dummyCert) Expired(t time.Time) bool { + return false +} + +func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error { + return nil +} + +func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error { + return nil +} + +func (d *dummyCert) String() string { + return "" +} + +func (d *dummyCert) Marshal() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) MarshalPEM() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Fingerprint() (string, error) { + return "", nil +} + +func (d *dummyCert) MarshalJSON() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Copy() cert.Certificate { + return d +} diff --git a/connection_state.go b/connection_state.go index 5373f96..e0b1ab3 100644 --- a/connection_state.go +++ b/connection_state.go @@ -3,6 +3,7 @@ package nebula import ( "crypto/rand" "encoding/json" + "fmt" "sync/atomic" "github.com/flynn/noise" @@ -17,50 +18,54 @@ type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState H *noise.HandshakeState - myCert *cert.NebulaCertificate - peerCert *cert.NebulaCertificate + myCert cert.Certificate + peerCert *cert.CachedCertificate initiator bool messageCounter atomic.Uint64 window *Bits writeLock syncMutex } -func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc - switch certState.Certificate.Details.Curve { + switch crt.Curve() { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: - dhFunc = noiseutil.DHP256 + if cs.pkcs11Backed { + dhFunc = noiseutil.DHP256PKCS11 + } else { + dhFunc = noiseutil.DHP256 + } default: - l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve) - return nil + return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) } - var cs noise.CipherSuite - if cipher == "chachapoly" { - cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) + var ncs noise.CipherSuite + if cs.cipher == "chachapoly" { + ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) } else { - cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) + ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } - static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} + static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} b := NewBits(ReplayWindow) - // Clear out bit 0, we never transmit it and we don't want it showing as packet loss + // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss b.Update(l, 0) hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: cs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - PresharedKey: psk, - PresharedKeyPlacement: pskStage, + CipherSuite: ncs, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: static, + //NOTE: These should come from CertState (pki.go) when we finally implement it + PresharedKey: []byte{}, + PresharedKeyPlacement: 0, }) if err != nil { - return nil + return nil, fmt.Errorf("NewConnectionState: %s", err) } // The queue and ready params prevent a counter race that would happen when @@ -69,11 +74,14 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i H: hs, initiator: initiator, window: b, - myCert: certState.Certificate, + myCert: crt, + writeLock: newSyncMutex("connection-state-write"), } + // always start the counter from 2, as packet 1 and packet 2 are handshake packets. + ci.messageCounter.Add(2) - return ci + return ci, nil } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { @@ -83,3 +91,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "message_counter": cs.messageCounter.Load(), }) } + +func (cs *ConnectionState) Curve() cert.Curve { + return cs.myCert.Curve() +} diff --git a/control.go b/control.go index c227b20..20dd7fe 100644 --- a/control.go +++ b/control.go @@ -2,7 +2,7 @@ package nebula import ( "context" - "net" + "net/netip" "os" "os/signal" "syscall" @@ -10,9 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" - "github.com/slackhq/nebula/udp" ) // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching @@ -21,10 +19,10 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo + QueryVpnAddr(vpnAddr netip.Addr) *HostInfo ForEachIndex(each controlEach) - ForEachVpnIp(each controlEach) - GetPreferredRanges() []*net.IPNet + ForEachVpnAddr(each controlEach) + GetPreferredRanges() []netip.Prefix } type Control struct { @@ -39,15 +37,15 @@ type Control struct { } type ControlHostInfo struct { - VpnIp net.IP `json:"vpnIp"` - LocalIndex uint32 `json:"localIndex"` - RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []*udp.Addr `json:"remoteAddrs"` - Cert *cert.NebulaCertificate `json:"cert"` - MessageCounter uint64 `json:"messageCounter"` - CurrentRemote *udp.Addr `json:"currentRemote"` - CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"` - CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"` + VpnAddrs []netip.Addr `json:"vpnAddrs"` + LocalIndex uint32 `json:"localIndex"` + RemoteIndex uint32 `json:"remoteIndex"` + RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` + Cert cert.Certificate `json:"cert"` + MessageCounter uint64 `json:"messageCounter"` + CurrentRemote netip.AddrPort `json:"currentRemote"` + CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` + CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() @@ -131,8 +129,48 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } } -// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found -func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { +// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found +func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { + _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) + if found { + // Only returning the default certificate since its impossible + // for any other host but ourselves to have more than 1 + return c.f.pki.getCertState().GetDefaultCertificate().Copy() + } + hi := c.f.hostMap.QueryVpnAddr(vpnIp) + if hi == nil { + return nil + } + return hi.GetCert().Certificate.Copy() +} + +// CreateTunnel creates a new tunnel to the given vpn ip. +func (c *Control) CreateTunnel(vpnIp netip.Addr) { + c.f.handshakeManager.StartHandshake(vpnIp, nil) +} + +// PrintTunnel creates a new tunnel to the given vpn ip. +func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { + hi := c.f.hostMap.QueryVpnAddr(vpnIp) + if hi == nil { + return nil + } + chi := copyHostInfo(hi, c.f.hostMap.GetPreferredRanges()) + return &chi +} + +// QueryLighthouse queries the lighthouse. +func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap { + hi := c.f.lightHouse.Query(vpnIp) + if hi == nil { + return nil + } + return hi.CopyCache() +} + +// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -140,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH hl = c.f.hostMap } - h := hl.QueryVpnIp(vpnIp) + h := hl.QueryVpnAddr(vpnAddr) if h == nil { return nil } @@ -150,20 +188,22 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH } // SetRemoteForTunnel forces a tunnel to use a specific remote -func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return nil } - hostInfo.SetRemote(addr.Copy()) + hostInfo.SetRemote(addr) ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) return &ch } // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. -func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return false } @@ -187,29 +227,24 @@ func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels // the int returned is a count of tunnels closed func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { - //TODO: this is probably better as a function in ConnectionManager or HostMap directly - lighthouses := c.f.lightHouse.GetLighthouses() - shutdown := func(h *HostInfo) { - if excludeLighthouses { - if _, ok := lighthouses[h.vpnIp]; ok { - return - } + if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { + return } c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote). + c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). Debug("Sending close tunnel message") closed++ } // Learn which hosts are being used as relays, so we can shut them down last. - relayingHosts := map[iputil.VpnIp]*HostInfo{} + relayingHosts := map[netip.Addr]*HostInfo{} // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { - relayingHosts[relayingHost.vpnIp] = relayingHost + relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost } c.f.hostMap.Unlock() @@ -217,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Hosts map c.f.hostMap.Lock() for _, relayHost := range c.f.hostMap.Indexes { - if _, ok := relayingHosts[relayHost.vpnIp]; !ok { + if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok { hostInfos = append(hostInfos, relayHost) } } @@ -236,15 +271,19 @@ func (c *Control) Device() overlay.Device { return c.f.inside } -func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { - +func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { chi := ControlHostInfo{ - VpnIp: h.vpnIp.ToIP(), + VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)), LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), + CurrentRemote: h.remote, + } + + for i, a := range h.vpnAddrs { + chi.VpnAddrs[i] = a } if h.ConnectionState != nil { @@ -252,11 +291,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { } if c := h.GetCert(); c != nil { - chi.Cert = c.Copy() - } - - if h.remote != nil { - chi.CurrentRemote = h.remote.Copy() + chi.Cert = c.Certificate.Copy() } return chi @@ -265,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { func listHostMapHosts(hl controlHostLister) []ControlHostInfo { hosts := make([]ControlHostInfo, 0) pr := hl.GetPreferredRanges() - hl.ForEachVpnIp(func(hostinfo *HostInfo) { + hl.ForEachVpnAddr(func(hostinfo *HostInfo) { hosts = append(hosts, copyHostInfo(hostinfo, pr)) }) return hosts diff --git a/control_test.go b/control_test.go index c64a3a4..e400992 100644 --- a/control_test.go +++ b/control_test.go @@ -2,72 +2,66 @@ package nebula import ( "net" + "net/netip" "reflect" "testing" - "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) func TestControl_GetHostInfoByVpnIp(t *testing.T) { + //TODO: CERT-V2 with multiple certificate versions we have a problem with this test + // Some certs versions have different characteristics and each version implements their own Copy() func + // which means this is not a good place to test for exposing memory l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := newHostMap(l, &net.IPNet{}) - hm.preferredRanges.Store(&[]*net.IPNet{}) + hm := newHostMap(l) + hm.preferredRanges.Store(&[]netip.Prefix{}) + + remote1 := netip.MustParseAddrPort("0.0.0.100:4444") + remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444") - remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) - remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), + IP: remote1.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } ipNet2 := net.IPNet{ - IP: net.ParseIP("1:2:3:4:5:6:7:8"), + IP: remote2.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } - crt := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Ips: []*net.IPNet{&ipNet}, - Subnets: []*net.IPNet{}, - Groups: []string{"default-group"}, - NotBefore: time.Unix(1, 0), - NotAfter: time.Unix(2, 0), - PublicKey: []byte{5, 6, 7, 8}, - IsCA: false, - Issuer: "the-issuer", - InvertedGroups: map[string]struct{}{"default-group": {}}, - }, - Signature: []byte{1, 2, 1, 2, 1, 3}, - } + remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port())) - remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) - remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) + vpnIp, ok := netip.AddrFromSlice(ipNet.IP) + assert.True(t, ok) + + crt := &dummyCert{} hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ - peerCert: crt, + peerCert: &cert.CachedCertificate{Certificate: crt}, }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnAddrs: []netip.Addr{vpnIp}, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) + vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -76,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet2.IP), + vpnAddrs: []netip.Addr{vpnIp2}, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -91,31 +85,32 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false) + thi := c.GetHostInfoByVpnAddr(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: net.IPv4(1, 2, 3, 4).To4(), + VpnAddrs: []netip.Addr{vpnIp}, LocalIndex: 201, RemoteIndex: 200, - RemoteAddrs: []*udp.Addr{remote2, remote1}, + RemoteAddrs: []netip.AddrPort{remote2, remote1}, Cert: crt.Copy(), MessageCounter: 0, - CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), - CurrentRelaysToMe: []iputil.VpnIp{}, - CurrentRelaysThroughMe: []iputil.VpnIp{}, + CurrentRemote: remote1, + CurrentRelaysToMe: []netip.Addr{}, + CurrentRelaysThroughMe: []netip.Addr{}, } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assert.Equal(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false) + thi = c.GetHostInfoByVpnAddr(vpnIp2, false) }) } -func assertFields(t *testing.T, expected []string, actualStruct interface{}) { +func assertFields(t *testing.T, expected []string, actualStruct any) { val := reflect.ValueOf(actualStruct).Elem() fields := make([]string, val.NumField()) for i := 0; i < val.NumField(); i++ { diff --git a/control_tester.go b/control_tester.go index b786ba3..451dac5 100644 --- a/control_tester.go +++ b/control_tester.go @@ -4,14 +4,11 @@ package nebula import ( - "net" - - "github.com/slackhq/nebula/cert" + "net/netip" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -50,37 +47,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // This is necessary if you did not configure static hosts or are not running a lighthouse -func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { +func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - if v4 := toAddr.IP.To4(); v4 != nil { - remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port))) + if toAddr.Addr().Is4() { + remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))) + remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port())) } } // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp // This is necessary to inform an initiator of possible relays for communicating with a responder -func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) { +func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - uVpnIp := []uint32{} - for _, rVPnIp := range relayVpnIps { - uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp))) - } - - remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp) + remoteList.unlockedSetRelay(vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula @@ -107,20 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) { - ip := layers.IPv4{ - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().IP, - DstIP: toIp, +func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { + serialize := make([]gopacket.SerializableLayer, 0) + var netLayer gopacket.NetworkLayer + if toAddr.Is6() { + if !fromAddr.Is6() { + panic("Cant send ipv6 to ipv4") + } + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } else { + if !fromAddr.Is4() { + panic("Cant send ipv4 to ipv6") + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip } udp := layers.UDP{ SrcPort: layers.UDPPort(fromPort), DstPort: layers.UDPPort(toPort), } - err := udp.SetNetworkLayerForChecksum(&ip) + err := udp.SetNetworkLayerForChecksum(netLayer) if err != nil { panic(err) } @@ -130,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 ComputeChecksums: true, FixLengths: true, } - err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data)) + + serialize = append(serialize, &udp, gopacket.Payload(data)) + err = gopacket.SerializeLayers(buffer, opt, serialize...) if err != nil { panic(err) } @@ -138,16 +152,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() iputil.VpnIp { - return c.f.myVpnIp +func (c *Control) GetVpnAddrs() []netip.Addr { + return c.f.myVpnAddrs } -func (c *Control) GetUDPAddr() string { - return c.f.outside.(*udp.TesterConn).Addr.String() +func (c *Control) GetUDPAddr() netip.AddrPort { + return c.f.outside.(*udp.TesterConn).Addr } -func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) +func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { + hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp) if hostinfo == nil { return false } @@ -160,10 +174,10 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } -func (c *Control) GetCert() *cert.NebulaCertificate { - return c.f.pki.GetCertState().Certificate +func (c *Control) GetCertState() *CertState { + return c.f.pki.getCertState() } -func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { +func (c *Control) ReHandshake(vpnIp netip.Addr) { c.f.handshakeManager.StartHandshake(vpnIp, nil) } diff --git a/dns_server.go b/dns_server.go index bc25adc..1194fdf 100644 --- a/dns_server.go +++ b/dns_server.go @@ -3,13 +3,14 @@ package nebula import ( "fmt" "net" + "net/netip" "strconv" "strings" + "github.com/gaissmai/bart" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This whole thing should be rewritten to use context @@ -20,74 +21,121 @@ var dnsAddr string type dnsRecords struct { syncRWMutex - dnsMap map[string]string - hostMap *HostMap + l *logrus.Logger + dnsMap4 map[string]netip.Addr + dnsMap6 map[string]netip.Addr + hostMap *HostMap + myVpnAddrsTable *bart.Table[struct{}] } -func newDnsRecords(hostMap *HostMap) *dnsRecords { +func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { return &dnsRecords{ - syncRWMutex: newSyncRWMutex("dns-records"), - dnsMap: make(map[string]string), - hostMap: hostMap, + syncRWMutex: newSyncRWMutex("dns-records"), + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + myVpnAddrsTable: cs.myVpnAddrsTable, } } -func (d *dnsRecords) Query(data string) string { +func (d *dnsRecords) Query(q uint16, data string) netip.Addr { + data = strings.ToLower(data) d.RLock() defer d.RUnlock() - if r, ok := d.dnsMap[strings.ToLower(data)]; ok { - return r + switch q { + case dns.TypeA: + if r, ok := d.dnsMap4[data]; ok { + return r + } + case dns.TypeAAAA: + if r, ok := d.dnsMap6[data]; ok { + return r + } } - return "" + + return netip.Addr{} } func (d *dnsRecords) QueryCert(data string) string { - ip := net.ParseIP(data[:len(data)-1]) - if ip == nil { + ip, err := netip.ParseAddr(data[:len(data)-1]) + if err != nil { return "" } - iip := iputil.Ip2VpnIp(ip) - hostinfo := d.hostMap.QueryVpnIp(iip) + + hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" } + q := hostinfo.GetCert() if q == nil { return "" } - cert := q.Details - c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) - return c + + b, err := q.Certificate.MarshalJSON() + if err != nil { + return "" + } + return string(b) } -func (d *dnsRecords) Add(host, data string) { +// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` +func (d *dnsRecords) Add(host string, addresses []netip.Addr) { + host = strings.ToLower(host) d.Lock() defer d.Unlock() - d.dnsMap[strings.ToLower(host)] = data + haveV4 := false + haveV6 := false + for _, addr := range addresses { + if addr.Is4() && !haveV4 { + d.dnsMap4[host] = addr + haveV4 = true + } else if addr.Is6() && !haveV6 { + d.dnsMap6[host] = addr + haveV6 = true + } + if haveV4 && haveV6 { + break + } + } } -func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { + a, _, _ := net.SplitHostPort(addr) + b, err := netip.ParseAddr(a) + if err != nil { + return false + } + + if b.IsLoopback() { + return true + } + + _, found := d.myVpnAddrsTable.Lookup(b) + return found //if we found it in this table, it's good +} + +func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { - case dns.TypeA: - l.Debugf("Query for A %s", q.Name) - ip := dnsR.Query(q.Name) - if ip != "" { - rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) + case dns.TypeA, dns.TypeAAAA: + qType := dns.TypeToString[q.Qtype] + d.l.Debugf("Query for %s %s", qType, q.Name) + ip := d.Query(q.Qtype, q.Name) + if ip.IsValid() { + rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { m.Answer = append(m.Answer, rr) } } case dns.TypeTXT: - a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b := net.ParseIP(a) - // We don't answer these queries from non nebula nodes or localhost - //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) - if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { + // We only answer these queries from nebula nodes or localhost + if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - l.Debugf("Query for TXT %s", q.Name) - ip := dnsR.QueryCert(q.Name) + d.l.Debugf("Query for TXT %s", q.Name) + ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) if err == nil { @@ -102,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } } -func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false switch r.Opcode { case dns.OpcodeQuery: - parseQuery(l, m, w) + d.parseQuery(m, w) } w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(hostMap) +func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { + dnsR = newDnsRecords(l, cs, hostMap) // attach request handler func - dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { - handleDnsRequest(l, w, r) - }) + dns.HandleFunc(".", dnsR.handleDnsRequest) c.RegisterReloadCallback(func(c *config.C) { reloadDns(l, c) diff --git a/dns_server_test.go b/dns_server_test.go index 69f6ae8..356e589 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,46 +1,61 @@ package nebula import ( + "net/netip" "testing" "github.com/miekg/dns" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" ) func TestParsequery(t *testing.T) { - //TODO: This test is basically pointless + l := logrus.New() hostMap := &HostMap{} - ds := newDnsRecords(hostMap) - ds.Add("test.com.com", "1.2.3.4") + ds := newDnsRecords(l, &CertState{}, hostMap) + addrs := []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.5"), + netip.MustParseAddr("fd01::24"), + netip.MustParseAddr("fd01::25"), + } + ds.Add("test.com.com", addrs) - m := new(dns.Msg) + m := &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) - //parseQuery(m) + m = &dns.Msg{} + m.SetQuestion("test.com.com", dns.TypeAAAA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) } func Test_getDnsServerAddr(t *testing.T) { c := config.NewC(nil) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "0.0.0.0", "port": "1", }, } assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "::", "port": "1", }, } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "[::]", "port": "1", }, @@ -48,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) { assert.Equal(t, "[::]:1", getDnsServerAddr(c)) // Make sure whitespace doesn't mess us up - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "[::] ", "port": "1", }, diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 59f1d0e..bc080ce 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -4,28 +4,32 @@ package e2e import ( - "fmt" - "net" + "net/netip" + "slices" "testing" "time" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v2" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -35,7 +39,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -44,19 +48,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -77,38 +81,31 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() - //TODO: assert hostmaps } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - // The IPs here are chosen on purpose: - // The current remote handling will sort by preference, public, and then lexically. - // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) - - // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -119,41 +116,139 @@ func TestWrongResponderHandshake(t *testing.T) { theirControl.Start() evilControl.Start() - t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + t.Log("Start the handshake process, we will route until we see the evil tunnel closed") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { - h := &header.H{} err := h.Parse(p.Data) if err != nil { panic(err) } - if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 { + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { return router.RouteAndExit } return router.KeepRouting }) - //TODO: Assert pending hostmap - I should have a correct hostinfo for them now + t.Log("Evil tunnel is closed, inject the correct udp addr for them") + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) + assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) + + t.Log("Route until we see the cached packet") + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if p.To == theirUdpAddr && h.Type == 1 { + return router.RouteAndExit + } + + return router.KeepRouting + }) t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") + + r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) + t.Log("Success!") + myControl.Stop() + theirControl.Stop() +} + +func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) + o := m{ + "static_host_map": m{ + theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()}, + }, + } + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o) + + // Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr. + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl, evilControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + evilControl.Start() + + t.Log("Start the handshake process, we will route until we see the evil tunnel closed") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + h := &header.H{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("Evil tunnel is closed, inject the correct udp addr for them") + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) + assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) + + t.Log("Route until we see the cached packet") + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if p.To == theirUdpAddr && h.Type == 1 { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("My cached packet should be received by them") + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Test the tunnel with them") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Flush all packets from all controllers") + r.FlushAll() + + t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete - //TODO: assert hostmaps for everyone r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) t.Log("Success!") myControl.Stop() @@ -164,13 +259,13 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -181,8 +276,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -194,14 +289,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -219,7 +314,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -241,13 +336,13 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -258,28 +353,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() - myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -290,13 +385,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -307,30 +402,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") theirHostmap := theirControl.GetHostmap() - theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -341,15 +436,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -361,31 +456,161 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) - //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it +} + +func TestReestablishRelays(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Ensure packet traversal from them to me via the relay") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + + p = r.RouteForAllUntilTxTun(myControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from them"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) + + // If we break the relay's connection to 'them', 'me' needs to detect and recover the connection + r.Log("Close the tunnel") + relayControl.CloseTunnel(theirVpnIpNet[0].Addr(), true) + + start := len(myControl.GetHostmap().Indexes) + curIndexes := len(myControl.GetHostmap().Indexes) + for curIndexes >= start { + curIndexes = len(myControl.GetHostmap().Indexes) + r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) + + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + return router.RouteAndExit + }) + time.Sleep(2 * time.Second) + } + r.Log("Dead index went away. Woot!") + r.RenderHostmaps("Me removed hostinfo", myControl, relayControl, theirControl) + // Next packet should re-establish a relayed connection and work just great. + + t.Logf("Assert the tunnel...") + for { + t.Log("RouteForAllUntilTxTun") + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p = r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if slices.Compare(v4.SrcIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + if slices.Compare(v4.DstIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("DstIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udp == nil { + t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") + continue + } + data := packet.ApplicationLayer() + if data == nil { + t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") + continue + } + if string(data.Payload()) != "Hi from me" { + t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) + continue + } + t.Log("I found my lost packet. I am so happy.") + break + } + t.Log("Assert the tunnel works the other way, too") + for { + t.Log("RouteForAllUntilTxTun") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + + p = r.RouteForAllUntilTxTun(myControl) + r.Log("Assert the tunnel works") + packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if slices.Compare(v4.DstIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("Dst is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + if slices.Compare(v4.SrcIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udp == nil { + t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") + continue + } + data := packet.ApplicationLayer() + if data == nil { + t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") + continue + } + if string(data.Payload()) != "Hi from them" { + t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) + continue + } + t.Log("I found my lost packet. I am so happy.") + break + } + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) + } func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -397,14 +622,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -415,27 +640,25 @@ func TestStage1RaceRelays(t *testing.T) { myControl.Stop() theirControl.Stop() relayControl.Stop() - // - ////TODO: assert hostmaps } func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -448,16 +671,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -470,7 +693,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -490,7 +713,7 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- @@ -498,25 +721,23 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() relayControl.Stop() - - // - ////TODO: assert hostmaps } + func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -528,19 +749,19 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -551,14 +772,14 @@ func TestRehandshakingRelays(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break @@ -569,9 +790,9 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break @@ -581,13 +802,13 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -595,7 +816,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -603,7 +824,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -612,15 +833,15 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -632,19 +853,19 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -655,14 +876,14 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break @@ -673,9 +894,9 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break @@ -685,13 +906,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -699,7 +920,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -707,7 +928,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -715,13 +936,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -732,14 +953,14 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -750,13 +971,13 @@ func TestRehandshaking(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(myConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now break } @@ -764,37 +985,38 @@ func TestRehandshaking(t *testing.T) { time.Sleep(time.Second) } + r.Log("Got the new cert") // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(theirConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) var theirNewConfig m - assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) - theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) + require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) + theirFirewall := theirNewConfig["firewall"].(map[string]any) theirFirewall["inbound"] = []m{{ "proto": "any", "port": "any", "group": "new group", }} rc, err = yaml.Marshal(theirNewConfig) - assert.NoError(t, err) + require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) - assert.Contains(t, c.Cert.Details.Groups, "new group") + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.Contains(t, c.Cert.Groups(), "new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) @@ -811,13 +1033,13 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -828,18 +1050,14 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - - tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) - tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) - fmt.Println(tt1.LocalIndex, tt2.LocalIndex) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -850,15 +1068,14 @@ func TestRehandshakingLoser(t *testing.T) { "key": string(theirNextPrivKey), } rc, err := yaml.Marshal(theirConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) - _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] - if theirNewGroup { + if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { break } @@ -867,35 +1084,35 @@ func TestRehandshakingLoser(t *testing.T) { // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(myConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) var myNewConfig m - assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) - theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) + require.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) + theirFirewall := myNewConfig["firewall"].(map[string]any) theirFirewall["inbound"] = []m{{ "proto": "any", "port": "any", "group": "their new group", }} rc, err = yaml.Marshal(myNewConfig) - assert.NoError(t, err) + require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) - assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) @@ -912,13 +1129,13 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -932,8 +1149,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -944,7 +1161,6 @@ func TestRaceRegression(t *testing.T) { myControl.InjectUDPPacket(theirStage1ForMe) theirControl.InjectUDPPacket(myStage1ForThem) - //TODO: ensure stage 2 t.Log("Get both stage 2") myStage2ForThem := myControl.GetFromUDP(true) theirStage2ForMe := theirControl.GetFromUDP(true) @@ -963,14 +1179,48 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myControl.Stop() theirControl.Stop() } -//TODO: test -// Race winner renews and handshakes -// Race loser renews and handshakes -// Does race winner repin the cert to old? -//TODO: add a test with many lies +func TestV2NonPrimaryWithLighthouse(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) + + o := m{ + "static_host_map": m{ + lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()}, + }, + "lighthouse": m{ + "hosts": []string{lhVpnIpNet[1].Addr().String()}, + "local_allow_list": m{ + // Try and block our lighthouse updates from using the actual addresses assigned to this computer + // If we start discovering addresses the test router doesn't know about then test traffic cant flow + "10.0.0.0/24": true, + "::/0": false, + }, + }, + } + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o) + theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, lhControl, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + lhControl.Start() + myControl.Start() + theirControl.Start() + + t.Log("Stand up an ipv6 tunnel between me and them") + assert.True(t, myVpnIpNet[1].Addr().Is6()) + assert.True(t, theirVpnIpNet[1].Addr().Is6()) + assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r) + + lhControl.Stop() + myControl.Stop() + theirControl.Stop() +} diff --git a/e2e/helpers.go b/e2e/helpers.go deleted file mode 100644 index 13146ab..0000000 --- a/e2e/helpers.go +++ /dev/null @@ -1,118 +0,0 @@ -package e2e - -import ( - "crypto/rand" - "io" - "net" - "time" - - "github.com/slackhq/nebula/cert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" -) - -// NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = ips - } - - if len(subnets) > 0 { - nc.Details.Subnets = subnets - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(cert.Curve_CURVE25519, priv) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, priv, pem -} - -// NewTestCert will generate a signed certificate with the provided details. -// Expiry times are defaulted if you do not pass them in -func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - issuer, err := ca.Sha256Sum() - if err != nil { - panic(err) - } - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - pub, rawPriv := x25519Keypair() - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ip}, - Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, - } - - err = nc.Sign(ca.Details.Curve, key) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey -} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index b05c84a..a63b3d0 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -6,8 +6,9 @@ package e2e import ( "fmt" "io" - "net" + "net/netip" "os" + "strings" "testing" "time" @@ -17,29 +18,47 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) { +func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - vpnIpNet.IP[1] += 128 - udpAddr := net.UDPAddr{ - IP: udpIp, - Port: 4242, + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) } - _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) - caB, err := caCrt.MarshalToPEM() + if len(vpnNetworks) == 0 { + panic("no vpn networks") + } + + var udpAddr netip.AddrPort + if vpnNetworks[0].Addr().Is4() { + budpIp := vpnNetworks[0].Addr().As4() + budpIp[1] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) + } else { + budpIp := vpnNetworks[0].Addr().As16() + // beef for funsies + budpIp[2] = 190 + budpIp[3] = 239 + udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) + } + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) + + caB, err := caCrt.MarshalPEM() if err != nil { panic(err) } @@ -67,8 +86,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u // "try_interval": "1s", //}, "listen": m{ - "host": udpAddr.IP.String(), - "port": udpAddr.Port, + "host": udpAddr.Addr().String(), + "port": udpAddr.Port(), }, "logging": m{ "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), @@ -81,11 +100,16 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u } if overrides != nil { - err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + final := m{} + err = mergo.Merge(&final, overrides, mergo.WithAppendSlice) if err != nil { panic(err) } - mc = overrides + err = mergo.Merge(&final, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = final } cb, err := yaml.Marshal(mc) @@ -102,7 +126,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - return control, vpnIpNet, &udpAddr, c + return control, vpnNetworks, udpAddr, c } type doneCb func() @@ -123,64 +147,54 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { } } -func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { +func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me - controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) + controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them - controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) + controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false) - assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") + //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things + hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) + assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") - hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false) - assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") + hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) + assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") // Check that both vpn and real addr are correct - assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") - assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") + assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") + assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B") - assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") - assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") - - assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A") - assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B") + assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") + assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") - - //TODO: Would be nice to assert this memory - //checkIndexes := func(name string, hm *HostMap, hi *HostInfo) { - // hBbyIndex := hmA.Indexes[hBinA.localIndexId] - // assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name) - // assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name) - // - // //TODO: remote indexes are susceptible to collision - // hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId] - // assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name) - // assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name) - //} - // - //// Check hostmap indexes too - //checkIndexes("hmA", hmA, hBinA) - //checkIndexes("hmB", hmB, hAinB) } -func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) { - packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) - v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) - assert.NotNil(t, v4, "No ipv4 data found") +func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + if toIp.Is6() { + assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) + } else { + assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort) + } +} - assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect") - assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect") +func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + assert.NotNil(t, v6, "No ipv6 data found") + + assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect") udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) assert.NotNil(t, udp, "No udp data found") @@ -193,6 +207,33 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, from assert.Equal(t, expected, data.Payload(), "Data was incorrect") } +func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + assert.NotNil(t, v4, "No ipv4 data found") + + assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect") + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + assert.NotNil(t, udp, "No udp data found") + + assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect") + assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect") + + data := packet.ApplicationLayer() + assert.NotNil(t, data) + assert.Equal(t, expected, data.Payload(), "Data was incorrect") +} + +func getAddrs(ns []netip.Prefix) []netip.Addr { + var a []netip.Addr + for _, n := range ns { + a = append(a, n.Addr()) + } + return a +} + func NewTestLogger() *logrus.Logger { l := logrus.New() diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 120be69..f2805d0 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -5,11 +5,11 @@ package router import ( "fmt" + "net/netip" "sort" "strings" "github.com/slackhq/nebula" - "github.com/slackhq/nebula/iputil" ) type edge struct { @@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { var lines []string var globalLines []*edge - clusterName := strings.Trim(c.GetCert().Details.Name, " ") - clusterVpnIp := c.GetCert().Details.Ips[0].IP + crt := c.GetCertState().GetDefaultCertificate() + clusterName := strings.Trim(crt.Name(), " ") + clusterVpnIp := crt.Networks()[0].Addr() r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() @@ -101,8 +102,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { for _, idx := range indexes { hi, ok := hm.Indexes[idx] if ok { - r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) - remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") + r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs()) + remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) _ = hi } @@ -118,14 +119,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { return r, globalLines } -func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp { - keys := make([]iputil.VpnIp, 0, len(hosts)) +func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr { + keys := make([]netip.Addr, 0, len(hosts)) for key := range hosts { keys = append(keys, key) } sort.SliceStable(keys, func(i, j int) bool { - return keys[i] > keys[j] + return keys[i].Compare(keys[j]) > 0 }) return keys diff --git a/e2e/router/router.go b/e2e/router/router.go index 730853a..5e52ed7 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -6,13 +6,12 @@ package router import ( "context" "fmt" - "net" + "net/netip" "os" "path/filepath" "reflect" + "regexp" "sort" - "strconv" - "strings" "sync" "testing" "time" @@ -21,7 +20,6 @@ import ( "github.com/google/gopacket/layers" "github.com/slackhq/nebula" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/exp/maps" ) @@ -29,18 +27,18 @@ import ( type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? - controls map[string]*nebula.Control + controls map[netip.AddrPort]*nebula.Control // A map for inbound packets for a control that doesn't know about this address - inNat map[string]*nebula.Control + inNat map[netip.AddrPort]*nebula.Control // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver - outNat map[string]net.UDPAddr + outNat map[string]netip.AddrPort // A map of vpn ip to the nebula control it belongs to - vpnControls map[iputil.VpnIp]*nebula.Control + vpnControls map[netip.Addr]*nebula.Control ignoreFlows []ignoreFlow flow []flowEntry @@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { } r := &R{ - controls: make(map[string]*nebula.Control), - vpnControls: make(map[iputil.VpnIp]*nebula.Control), - inNat: make(map[string]*nebula.Control), - outNat: make(map[string]net.UDPAddr), + controls: make(map[netip.AddrPort]*nebula.Control), + vpnControls: make(map[netip.Addr]*nebula.Control), + inNat: make(map[netip.AddrPort]*nebula.Control), + outNat: make(map[string]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), @@ -135,10 +133,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { for _, c := range controls { addr := c.GetUDPAddr() if _, ok := r.controls[addr]; ok { - panic("Duplicate listen address: " + addr) + panic("Duplicate listen address: " + addr.String()) + } + + for _, vpnAddr := range c.GetVpnAddrs() { + r.vpnControls[vpnAddr] = c } - r.vpnControls[c.GetVpnIp()] = c r.controls[addr] = c } @@ -165,13 +166,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { // It does not look at the addr attached to the instance. // If a route is used, this will behave like a NAT for the return path. // Rewriting the source ip:port to what was last sent to from the origin -func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { +func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { r.Lock() defer r.Unlock() - inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)) + inAddr := netip.AddrPortFrom(ip, port) if _, ok := r.inNat[inAddr]; ok { - panic("Duplicate listen address inNat: " + inAddr) + panic("Duplicate listen address inNat: " + inAddr.String()) } r.inNat[inAddr] = c } @@ -198,7 +199,7 @@ func (r *R) renderFlow() { panic(err) } - var participants = map[string]struct{}{} + var participants = map[netip.AddrPort]struct{}{} var participantsVals []string fmt.Fprintln(f, "```mermaid") @@ -215,11 +216,11 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr, ":", "-", 1) + sanAddr := normalizeName(addr.String()) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", - sanAddr, e.packet.from.GetVpnIp(), sanAddr, + sanAddr, e.packet.from.GetVpnAddrs(), sanAddr, ) } @@ -252,9 +253,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr(), ":", "-", 1), + normalizeName(p.from.GetUDPAddr().String()), line, - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + normalizeName(p.to.GetUDPAddr().String()), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -269,6 +270,11 @@ func (r *R) renderFlow() { } } +func normalizeName(s string) string { + rx := regexp.MustCompile("[\\[\\]\\:]") + return rx.ReplaceAllLiteralString(s, "_") +} + // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria. // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered @@ -305,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp() > c[j].GetVpnIp() + return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0 }) s := renderHostmaps(c...) @@ -420,13 +426,12 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: - outAddr := sender.GetUDPAddr() r.Lock() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + a := sender.GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) c.InjectUDPPacket(p) @@ -479,13 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + a := cm[x].GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic(fmt.Sprintf("No control for udp tx %s", p.To)) } fp := r.unlockedInjectFlow(cm[x], c, p, false) c.InjectUDPPacket(p) @@ -509,12 +512,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { panic(err) } - outAddr := sender.GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(sender.GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -590,13 +591,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // If the router doesn't have the nebula controller for that address, we panic -func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) { +func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) { if finish == KeepRouting { finish = RouteAndExit } r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { - if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { + if p.To == toAddr { return finish } @@ -630,13 +631,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { r.Lock() p := rx.Interface().(*udp.Packet) - - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteForAllExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -697,12 +695,10 @@ func (r *R) FlushAll() { p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't FlushAll for host: " + p.To.String()) } r.Unlock() } @@ -710,28 +706,14 @@ func (r *R) FlushAll() { // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock -func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control { - if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { - p.FromIp = newAddr.IP - p.FromPort = uint16(newAddr.Port) +func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { + if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { + p.From = newAddr } c, ok := r.inNat[toAddr] if ok { - sHost, sPort, err := net.SplitHostPort(toAddr) - if err != nil { - panic(err) - } - - port, err := strconv.Atoi(sPort) - if err != nil { - panic(err) - } - - r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{ - IP: net.ParseIP(sHost), - Port: port, - } + r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr return c } @@ -739,29 +721,42 @@ func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control { } func (r *R) formatUdpPacket(p *packet) string { - packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) - v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) - if v4 == nil { - panic("not an ipv4 packet") + var packet gopacket.Packet + var srcAddr netip.Addr + + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy) + if packet.ErrorLayer() == nil { + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) + } else { + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) } from := "unknown" - if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok { - from = c.GetUDPAddr() + if c, ok := r.vpnControls[srcAddr]; ok { + from = c.GetUDPAddr().String() } - udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) - if udp == nil { + udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udpLayer == nil { panic("not a udp packet") } data := packet.ApplicationLayer() return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", - strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), - udp.SrcPort, - udp.DstPort, + normalizeName(from), + normalizeName(p.to.GetUDPAddr().String()), + udpLayer.SrcPort, + udpLayer.DstPort, string(data.Payload()), ) } diff --git a/examples/config.yml b/examples/config.yml index c74ffc6..534608d 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -13,6 +13,12 @@ pki: # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. #disconnect_invalid: true + # default_version controls which certificate version is used in handshakes. + # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`. + # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`. + # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. + # default_version: 1 + # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # The syntax is: @@ -120,8 +126,8 @@ lighthouse: # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: - # To listen on both any ipv4 and ipv6 use "::" - host: 0.0.0.0 + # To listen on only ipv4, use "0.0.0.0" + host: "::" port: 4242 # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) # default is 64, does not support reload @@ -138,6 +144,11 @@ listen: # valid values: always, never, private # This setting is reloadable. #send_recv_error: always + # The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier. + # This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes, + # allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set. + # This setting is reloadable. + #so_mark: 0 # Routines is the number of thread pairs to run that consume from the tun and UDP queues. # Currently, this defaults to 1 which means we have 1 tun queue reader and 1 @@ -228,7 +239,28 @@ tun: # Unsafe routes allows you to route traffic over nebula to non-nebula nodes # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula - # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate + # Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula + # NOTES: + # * You will only see a single gateway in the routing table if you are not on linux + # * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights + # + # unsafe_routes: + # # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways + # - route: 192.168.87.0/24 + # via: + # - gateway: 10.0.0.1 + # - gateway: 10.0.0.2 + # - gateway: 10.0.0.3 + # # Multiple gateways with a weight, this will balance traffic accordingly + # - route: 192.168.87.0/24 + # via: + # - gateway: 10.0.0.1 + # weight: 10 + # - gateway: 10.0.0.2 + # weight: 5 + # + # NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate + # `via`: single node or list of gateways to use for this route # `mtu`: will default to tun mtu if this option is not specified # `metric`: will default to 0 if this option is not specified # `install`: will default to true, controls whether this route is installed in the systems routing table. @@ -244,7 +276,6 @@ tun: # in nebula configuration files. Default false, not reloadable. #use_system_route_table: false -# TODO # Configure logging level logging: # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. @@ -315,11 +346,11 @@ firewall: outbound_action: drop inbound_action: drop - # Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false. - # This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an - # unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless - # of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr` - # if the intention is to allow traffic to flow to an unsafe route. + # THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.) + # This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a + # `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule + # will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr` + # is explicitly defined. This is usually not the desired behavior and should be avoided! #default_local_cidr_any: false conntrack: @@ -336,10 +367,10 @@ firewall: # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass - # cidr: a remote CIDR, `0.0.0.0/0` is any. - # local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. - # Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate - # if `default_local_cidr_any` is false, otherwise its `any`. + # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. + # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes. + # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. + # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case. # ca_name: An issuing CA name # ca_sha: An issuing CA shasum diff --git a/examples/go_service/main.go b/examples/go_service/main.go index f46273a..30178c0 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "log" + "net" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/service" @@ -54,16 +55,16 @@ pki: cert: /home/rice/Developer/nebula-config/app.crt key: /home/rice/Developer/nebula-config/app.key ` - var config config.C - if err := config.LoadString(configStr); err != nil { + var cfg config.C + if err := cfg.LoadString(configStr); err != nil { return err } - service, err := service.New(&config) + svc, err := service.New(&cfg) if err != nil { return err } - ln, err := service.Listen("tcp", ":1234") + ln, err := svc.Listen("tcp", ":1234") if err != nil { return err } @@ -73,16 +74,24 @@ pki: log.Printf("accept error: %s", err) break } - defer conn.Close() + defer func(conn net.Conn) { + _ = conn.Close() + }(conn) log.Printf("got connection") - conn.Write([]byte("hello world\n")) + _, err = conn.Write([]byte("hello world\n")) + if err != nil { + log.Printf("write error: %s", err) + } scanner := bufio.NewScanner(conn) for scanner.Scan() { message := scanner.Text() - fmt.Fprintf(conn, "echo: %q\n", message) + _, err = fmt.Fprintf(conn, "echo: %q\n", message) + if err != nil { + log.Printf("write error: %s", err) + } log.Printf("got message %q", message) } @@ -92,8 +101,8 @@ pki: } } - service.Close() - if err := service.Wait(); err != nil { + _ = svc.Close() + if err := svc.Wait(); err != nil { return err } return nil diff --git a/firewall.go b/firewall.go index d172bd9..eb57382 100644 --- a/firewall.go +++ b/firewall.go @@ -6,22 +6,22 @@ import ( "errors" "fmt" "hash/fnv" - "net" + "net/netip" "reflect" "strconv" "strings" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error } type conn struct { @@ -50,10 +50,13 @@ type Firewall struct { UDPTimeout time.Duration //linux: 180s max DefaultTimeout time.Duration //linux: 600s - // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4[struct{}] - assignedCIDR *net.IPNet - hasSubnets bool + // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. + // The vpn addresses are a full bit match while the unsafe networks only match the prefix + routableNetworks *bart.Table[struct{}] + + // assignedNetworks is a list of vpn networks assigned to us in the certificate. + assignedNetworks []netip.Prefix + hasUnsafeNetworks bool rules string rulesVersion uint16 @@ -66,9 +69,9 @@ type Firewall struct { } type firewallMetrics struct { - droppedLocalIP metrics.Counter - droppedRemoteIP metrics.Counter - droppedNoRule metrics.Counter + droppedLocalAddr metrics.Counter + droppedRemoteAddr metrics.Counter + droppedNoRule metrics.Counter } type FirewallConntrack struct { @@ -107,7 +110,7 @@ type FirewallRule struct { Any *firewallLocalCIDR Hosts map[string]*firewallLocalCIDR Groups []*firewallGroups - CIDR *cidr.Tree4[*firewallLocalCIDR] + CIDR *bart.Table[*firewallLocalCIDR] } type firewallGroups struct { @@ -121,85 +124,92 @@ type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { Any bool - LocalCIDR *cidr.Tree4[struct{}] + LocalCIDR *bart.Table[struct{}] } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { +// The certificate provided should be the highest version loaded in memory. +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration - var min, max time.Duration + var tmin, tmax time.Duration if tcpTimeout < UDPTimeout { - min = tcpTimeout - max = UDPTimeout + tmin = tcpTimeout + tmax = UDPTimeout } else { - min = UDPTimeout - max = tcpTimeout + tmin = UDPTimeout + tmax = tcpTimeout } - if defaultTimeout < min { - min = defaultTimeout - } else if defaultTimeout > max { - max = defaultTimeout + if defaultTimeout < tmin { + tmin = defaultTimeout + } else if defaultTimeout > tmax { + tmax = defaultTimeout } - localIps := cidr.NewTree4[struct{}]() - var assignedCIDR *net.IPNet - for _, ip := range c.Details.Ips { - ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}} - localIps.AddCIDR(ipNet, struct{}{}) - - if assignedCIDR == nil { - // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = ipNet - } + routableNetworks := new(bart.Table[struct{}]) + var assignedNetworks []netip.Prefix + for _, network := range c.Networks() { + nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) + routableNetworks.Insert(nprefix, struct{}{}) + assignedNetworks = append(assignedNetworks, network) } - for _, n := range c.Details.Subnets { - localIps.AddCIDR(n, struct{}{}) + hasUnsafeNetworks := false + for _, n := range c.UnsafeNetworks() { + routableNetworks.Insert(n, struct{}{}) + hasUnsafeNetworks = true } return &Firewall{ Conntrack: &FirewallConntrack{ syncMutex: newSyncMutex("firewall-conntrack"), Conns: make(map[firewall.Packet]*conn), - TimerWheel: NewTimerWheel[firewall.Packet](min, max), + TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, - InRules: newFirewallTable(), - OutRules: newFirewallTable(), - TCPTimeout: tcpTimeout, - UDPTimeout: UDPTimeout, - DefaultTimeout: defaultTimeout, - localIps: localIps, - assignedCIDR: assignedCIDR, - hasSubnets: len(c.Details.Subnets) > 0, - l: l, + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, + hasUnsafeNetworks: hasUnsafeNetworks, + l: l, incomingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), }, outgoingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), }, } } -func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { + certificate := cs.getCertificate(cert.Version2) + if certificate == nil { + certificate = cs.getCertificate(cert.Version1) + } + + if certificate == nil { + panic("No certificate available to reconfigure the firewall") + } + fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), - nc, + certificate, //TODO: max_connections ) - //TODO: Flip to false after v1.9 release - fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true) + fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) inboundAction := c.GetString("firewall.inbound_action", "drop") switch inboundAction { @@ -237,15 +247,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf } // AddRule properly creates the in memory rule structure for a firewall table. -func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // https://github.com/golang/go/issues/14131 sIp := "" - if ip != nil { + if ip.IsValid() { sIp = ip.String() } lIp := "" - if localIp != nil { + if localIp.IsValid() { lIp = localIp.String() } @@ -279,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort fp = ft.TCP case firewall.ProtoUDP: fp = ft.UDP - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -321,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return nil } - rs, ok := r.([]interface{}) + rs, ok := r.([]any) if !ok { return fmt.Errorf("%s failed to parse, should be an array of rules", table) } @@ -382,17 +392,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } - var cidr *net.IPNet + var cidr netip.Prefix if r.Cidr != "" { - _, cidr, err = net.ParseCIDR(r.Cidr) + cidr, err = netip.ParsePrefix(r.Cidr) if err != nil { return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) } } - var localCidr *net.IPNet + var localCidr netip.Prefix if r.LocalCidr != "" { - _, localCidr, err = net.ParseCIDR(r.LocalCidr) + localCidr, err = netip.ParsePrefix(r.LocalCidr) if err != nil { return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) } @@ -413,31 +423,31 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet if f.inConns(fp, h, caPool, localCache) { return nil } // Make sure remote address matches nebula certificate - if remoteCidr := h.remoteCidr; remoteCidr != nil { - ok, _ := remoteCidr.Contains(fp.RemoteIP) + if h.networks != nil { + _, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { - f.metrics(incoming).droppedRemoteIP.Inc(1) + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } else { - // Simple case: Certificate has one IP and no subnets - if fp.RemoteIP != h.vpnIp { - f.metrics(incoming).droppedRemoteIP.Inc(1) + // Simple case: Certificate has one address and no unsafe networks + if h.vpnAddrs[0] != fp.RemoteAddr { + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } // Make sure we are supposed to be handling this local ip address - ok, _ := f.localIps.Contains(fp.LocalIP) + _, ok := f.routableNetworks.Lookup(fp.LocalAddr) if !ok { - f.metrics(incoming).droppedLocalIP.Inc(1) + f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } @@ -482,7 +492,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true @@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! func (f *Firewall) evict(p firewall.Packet) { - //TODO: report a stat if the tcp rtt tracking was never resolved? // Are we still tracking this conn? conntrack := f.Conntrack t, ok := conntrack.Conns[p] @@ -610,7 +619,7 @@ func (f *Firewall) evict(p firewall.Packet) { delete(conntrack.Conns, p) } -func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if ft.AnyProto.match(p, incoming, c, caPool) { return true } @@ -624,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC if ft.UDP.match(p, incoming, c, caPool) { return true } - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: if ft.ICMP.match(p, incoming, c, caPool) { return true } @@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC return false } -func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -654,7 +663,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou return nil } -func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { // We don't have any allowed ports, bail if fp == nil { return false @@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { +func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]*firewallLocalCIDR), Groups: make([]*firewallGroups, 0), - CIDR: cidr.NewTree4[*firewallLocalCIDR](), + CIDR: new(bart.Table[*firewallLocalCIDR]), } } @@ -717,7 +726,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc return nil } -func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if fc == nil { return false } @@ -726,24 +735,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return true } - if t, ok := fc.CAShas[c.Details.Issuer]; ok { + if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok { if t.match(p, c) { return true } } - s, err := caPool.GetCAForCert(c) + s, err := caPool.GetCAForCert(c.Certificate) if err != nil { return false } - return fc.CANames[s.Details.Name].match(p, c) + return fc.CANames[s.Certificate.Name()].match(p, c) } -func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error { +func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ - LocalCIDR: cidr.NewTree4[struct{}](), + LocalCIDR: new(bart.Table[struct{}]), } } @@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n fr.Hosts[host] = nlc } - if ip != nil { - _, nlc := fr.CIDR.GetCIDR(ip) + if ip.IsValid() { + nlc, _ := fr.CIDR.Get(ip) if nlc == nil { nlc = flc() } @@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n if err != nil { return err } - fr.CIDR.AddCIDR(ip, nlc) + fr.CIDR.Insert(ip, nlc) } return nil } -func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { - if len(groups) == 0 && host == "" && ip == nil { +func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool { + if len(groups) == 0 && host == "" && !ip.IsValid() { return true } @@ -810,14 +819,14 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return true } - if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) { + if ip.IsValid() && ip.Bits() == 0 { return true } return false } -func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool { +func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool { if fr == nil { return false } @@ -832,7 +841,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool found := false for _, g := range sg.Groups { - if _, ok := c.Details.InvertedGroups[g]; !ok { + if _, ok := c.InvertedGroups[g]; !ok { found = false break } @@ -846,35 +855,44 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } if fr.Hosts != nil { - if flc, ok := fr.Hosts[c.Details.Name]; ok { + if flc, ok := fr.Hosts[c.Certificate.Name()]; ok { if flc.match(p, c) { return true } } } - return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool { - return flc.match(p, c) - }) + for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) { + if v.match(p, c) { + return true + } + } + + return false } -func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error { - if localIp == nil { - if !f.hasSubnets || f.defaultLocalCIDRAny { +func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { + if !localIp.IsValid() { + if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { flc.Any = true return nil } - localIp = f.assignedCIDR - } else if localIp.Contains(net.IPv4(0, 0, 0, 0)) { + for _, network := range f.assignedNetworks { + flc.LocalCIDR.Insert(network, struct{}{}) + } + return nil + + } else if localIp.Bits() == 0 { flc.Any = true + return nil } - flc.LocalCIDR.AddCIDR(localIp, struct{}{}) + flc.LocalCIDR.Insert(localIp, struct{}{}) return nil } -func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool { +func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool { if flc == nil { return false } @@ -883,7 +901,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate return true } - ok, _ := flc.LocalCIDR.Contains(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalAddr) return ok } @@ -900,15 +918,15 @@ type rule struct { CASha string } -func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) { +func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { r := rule{} - m, ok := p.(map[interface{}]interface{}) + m, ok := p.(map[string]any) if !ok { return r, errors.New("could not parse rule") } - toString := func(k string, m map[interface{}]interface{}) string { + toString := func(k string, m map[string]any) string { v, ok := m[k] if !ok { return "" @@ -926,7 +944,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er r.CASha = toString("ca_sha", m) // Make sure group isn't an array - if v, ok := m["group"].([]interface{}); ok { + if v, ok := m["group"].([]any); ok { if len(v) > 1 { return r, errors.New("group should contain a single value, an array with more than one entry was provided") } diff --git a/firewall/packet.go b/firewall/packet.go index 1c4affd..40c7fc5 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -3,25 +3,25 @@ package firewall import ( "encoding/json" "fmt" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) -type m map[string]interface{} +type m = map[string]any const ( - ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever - ProtoTCP = 6 - ProtoUDP = 17 - ProtoICMP = 1 + ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever + ProtoTCP = 6 + ProtoUDP = 17 + ProtoICMP = 1 + ProtoICMPv6 = 58 PortAny = 0 // Special value for matching `port: any` PortFragment = -1 // Special value for matching `port: fragment` ) type Packet struct { - LocalIP iputil.VpnIp - RemoteIP iputil.VpnIp + LocalAddr netip.Addr + RemoteAddr netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8 @@ -30,8 +30,8 @@ type Packet struct { func (fp *Packet) Copy() *Packet { return &Packet{ - LocalIP: fp.LocalIP, - RemoteIP: fp.RemoteIP, + LocalAddr: fp.LocalAddr, + RemoteAddr: fp.RemoteAddr, LocalPort: fp.LocalPort, RemotePort: fp.RemotePort, Protocol: fp.Protocol, @@ -52,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) { proto = fmt.Sprintf("unknown %v", fp.Protocol) } return json.Marshal(m{ - "LocalIP": fp.LocalIP.String(), - "RemoteIP": fp.RemoteIP.String(), + "LocalAddr": fp.LocalAddr.String(), + "RemoteAddr": fp.RemoteAddr.String(), "LocalPort": fp.LocalPort, "RemotePort": fp.RemotePort, "Protocol": proto, diff --git a/firewall_test.go b/firewall_test.go index b5beff6..4731a6f 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -4,21 +4,21 @@ import ( "bytes" "errors" "math" - "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewFirewall(t *testing.T) { l := test.NewLogger() - c := &cert.NebulaCertificate{} + c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack assert.NotNil(t, conntrack) @@ -60,64 +60,67 @@ func TestFirewall_AddRule(t *testing.T) { ob := &bytes.Buffer{} l.SetOutput(ob) - c := &cert.NebulaCertificate{} + c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) - _, ti, _ := net.ParseCIDR("1.2.3.4/32") + ti, err := netip.ParsePrefix("1.2.3.4/32") + require.NoError(t, err) - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) + _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) + _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) + anyIp, err := netip.ParsePrefix("0.0.0.0/0") + require.NoError(t, err) + + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) + require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -126,79 +129,74 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Groups: []string{"default-group"}, - InvertedGroups: map[string]struct{}{"default-group": {}}, - Issuer: "signer-shasum", - }, + c := dummyCert{ + name: "host1", + networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", } h := HostInfo{ ConnectionState: &ConnectionState{ - peerCert: &c, + peerCert: &cert.CachedCertificate{ + Certificate: &c, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch - oldRemote := p.RemoteIP - p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) + oldRemote := p.RemoteAddr + p.RemoteAddr = netip.MustParseAddr("1.2.3.10") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) - p.RemoteIP = oldRemote + p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks - cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match - cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -207,15 +205,16 @@ func BenchmarkFirewallTable_match(b *testing.B) { TCP: firewallPort{}, } - _, n, _ := net.ParseCIDR("172.1.1.1/32") - goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) - _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") - _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "") + pfix := netip.MustParsePrefix("172.1.1.1/32") + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { // This benchmark is showing us the cost of failing to match the protocol - c := &cert.NebulaCertificate{} + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)) } @@ -223,29 +222,31 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("pass proto, fail on port", func(b *testing.B) { // This benchmark is showing us the cost of matching a specific protocol but failing to match the port - c := &cert.NebulaCertificate{} + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)) } }) b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { - c := &cert.NebulaCertificate{} - ip, _, _ := net.ParseCIDR("9.254.254.254/32") - lip := iputil.Ip2VpnIp(ip) + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } + ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { - _, ip, _ := net.ParseCIDR("9.254.254.254/32") - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "nope", - Ips: []*net.IPNet{ip}, + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")}, }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) @@ -253,25 +254,24 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { - _, ip, _ := net.ParseCIDR("9.254.254.254/32") - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "nope", - Ips: []*net.IPNet{ip}, + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")}, }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) b.Run("pass on group on any local cidr", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"good-group": {}}, - Name: "nope", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", }, + InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) @@ -279,82 +279,28 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass on group on specific local cidr", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"good-group": {}}, - Name: "nope", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", }, + InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) b.Run("pass on name", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "good-host", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "good-host", }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) } }) - // - //b.Run("pass on ip", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) - // } - //}) - // - //b.Run("pass on local ip", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp) - // } - //}) - // - //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "") - // - //b.Run("pass on ip with any port", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) - // } - //}) - // - //b.Run("pass on local ip with any port", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp) - // } - //}) } func TestFirewall_Drop2(t *testing.T) { @@ -363,57 +309,55 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } + network := netip.MustParsePrefix("1.2.3.4/24") - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, }, + InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, } h := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) - c1 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, + c1 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, }, + InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, } h1 := HostInfo{ + vpnAddrs: []netip.Addr{network.Addr()}, ConnectionState: &ConnectionState{ peerCert: &c1, }, } - h1.CreateRemoteCIDR(&c1) + h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -422,84 +366,85 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host-owner", - Ips: []*net.IPNet{&ipNet}, + network := netip.MustParsePrefix("1.2.3.4/24") + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host-owner", + networks: []netip.Prefix{network}, }, } - c1 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha-bad", + c1 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + issuer: "signer-sha-bad", }, } h1 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnAddrs: []netip.Addr{network.Addr()}, } - h1.CreateRemoteCIDR(&c1) + h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) - c2 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host2", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha", + c2 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host2", + networks: []netip.Prefix{network}, + issuer: "signer-sha", }, } h2 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnAddrs: []netip.Addr{network.Addr()}, } - h2.CreateRemoteCIDR(&c2) + h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) - c3 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host3", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha-bad", + c3 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host3", + networks: []netip.Prefix{network}, + issuer: "signer-sha-bad", }, } h3 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnAddrs: []netip.Addr{network.Addr()}, } - h3.CreateRemoteCIDR(&c3) + h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match - assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) + + // Test a remote address match + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -508,60 +453,56 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, Fragment: false, } + network := netip.MustParsePrefix("1.2.3.4/24") - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Groups: []string{"default-group"}, - InvertedGroups: map[string]struct{}{"default-group": {}}, - Issuer: "signer-shasum", + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", }, + InvertedGroups: map[string]struct{}{"default-group": {}}, } h := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw := fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw = fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -640,104 +581,105 @@ func BenchmarkLookup(b *testing.B) { ml(m, a) } }) - - //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster } func Test_parsePort(t *testing.T) { _, _, err := parsePort("") - assert.EqualError(t, err, "was not a number; ``") + require.EqualError(t, err, "was not a number; ``") _, _, err = parsePort(" ") - assert.EqualError(t, err, "was not a number; ` `") + require.EqualError(t, err, "was not a number; ` `") _, _, err = parsePort("-") - assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`") + require.EqualError(t, err, "appears to be a range but could not be parsed; `-`") _, _, err = parsePort(" - ") - assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") + require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") _, _, err = parsePort("a-b") - assert.EqualError(t, err, "beginning range was not a number; `a`") + require.EqualError(t, err, "beginning range was not a number; `a`") _, _, err = parsePort("1-b") - assert.EqualError(t, err, "ending range was not a number; `b`") + require.EqualError(t, err, "ending range was not a number; `b`") s, e, err := parsePort(" 1 - 2 ") assert.Equal(t, int32(1), s) assert.Equal(t, int32(2), e) - assert.Nil(t, err) + require.NoError(t, err) s, e, err = parsePort("0-1") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.Nil(t, err) + require.NoError(t, err) s, e, err = parsePort("9919") assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), e) - assert.Nil(t, err) + require.NoError(t, err) s, e, err = parsePort("any") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.Nil(t, err) + require.NoError(t, err) } func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition - c := &cert.NebulaCertificate{} + c := &dummyCert{} + cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + require.NoError(t, err) + conf := config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} - _, err := NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") + conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} - _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} + _, err = NewFirewallFromConfig(l, cs, conf) + require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } func TestAddFirewallRulesFromConfig(t *testing.T) { @@ -745,87 +687,87 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding tcp rule conf := config.NewC(l) mf := &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr - cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)} + cidr := netip.MustParsePrefix("10.0.0.0/8") conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error conf = config.NewC(l) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} + require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } func TestFirewall_convertRule(t *testing.T) { @@ -834,33 +776,33 @@ func TestFirewall_convertRule(t *testing.T) { l.SetOutput(ob) // Ensure group array of 1 is converted and a warning is printed - c := map[interface{}]interface{}{ - "group": []interface{}{"group1"}, + c := map[string]any{ + "group": []any{"group1"}, } r, err := convertRule(l, c, "test", 1) assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "group1", r.Group) // Ensure group array of > 1 is errord ob.Reset() - c = map[interface{}]interface{}{ - "group": []interface{}{"group1", "group2"}, + c = map[string]any{ + "group": []any{"group1", "group2"}, } r, err = convertRule(l, c, "test", 1) - assert.Equal(t, "", ob.String()) - assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") + assert.Empty(t, ob.String()) + require.Error(t, err, "group should contain a single value, an array with more than one entry was provided") // Make sure a well formed group is alright ob.Reset() - c = map[interface{}]interface{}{ + c = map[string]any{ "group": "group1", } r, err = convertRule(l, c, "test", 1) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "group1", r.Group) } @@ -871,8 +813,8 @@ type addRuleCall struct { endPort int32 groups []string host string - ip *net.IPNet - localIp *net.IPNet + ip netip.Prefix + localIp netip.Prefix caName string caSha string } @@ -882,7 +824,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto, diff --git a/go.mod b/go.mod index c5e7ddd..f61ebf7 100644 --- a/go.mod +++ b/go.mod @@ -1,55 +1,58 @@ module github.com/slackhq/nebula -go 1.22.0 +go 1.23.6 -toolchain go1.22.2 +toolchain go1.24.1 require ( - dario.cat/mergo v1.0.0 + dario.cat/mergo v1.0.1 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 + github.com/gaissmai/bart v0.20.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.59 + github.com/miekg/dns v1.1.64 + github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.19.0 + github.com/prometheus/client_golang v1.21.1 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 - github.com/stretchr/testify v1.9.0 + github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 + github.com/stretchr/testify v1.10.0 github.com/timandy/routine v1.1.1 - github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.23.0 + github.com/vishvananda/netlink v1.3.0 + golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.25.0 - golang.org/x/sync v0.7.0 - golang.org/x/sys v0.20.0 - golang.org/x/term v0.20.0 + golang.org/x/net v0.38.0 + golang.org/x/sync v0.12.0 + golang.org/x/sys v0.31.0 + golang.org/x/term v0.30.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.34.1 - gopkg.in/yaml.v2 v2.4.0 + google.golang.org/protobuf v1.36.6 + gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) require ( github.com/beorn7/perks v1.0.1 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect + github.com/klauspost/compress v1.17.11 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.5.0 // indirect - github.com/prometheus/common v0.48.0 // indirect - github.com/prometheus/procfs v0.12.0 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect - golang.org/x/mod v0.16.0 // indirect + golang.org/x/mod v0.23.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.19.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + golang.org/x/tools v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index 44309d0..bc4c366 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= -dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= +dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -15,8 +15,8 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc h1:6e91sWiDE69Jl0WUsY/LvTCBPRBe6b2j8H7W96JGJ4s= github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc/go.mod h1:RGIcF96ORCYAsdz60Ou9mPBNa4+DjoQFS8nelPniFoY= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps= @@ -26,6 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo= +github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -68,6 +70,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -78,13 +82,19 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= -github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= +github.com/miekg/dns v1.1.64 h1:wuZgD9wwCE6XMT05UU/mlSko71eRSXEAm2EbjQXLKnQ= +github.com/miekg/dns v1.1.64/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg= @@ -98,24 +108,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= -github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= +github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= -github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= @@ -127,21 +137,20 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= +github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= +github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/timandy/routine v1.1.1 h1:6/Z7qLFZj3GrzuRksBFzIG8YGUh8CLhjnnMePBQTrEI= github.com/timandy/routine v1.1.1/go.mod h1:OZHPOKSvqL/ZvqXFkNZyit0xIVelERptYXdAHH00adQ= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= -github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -151,16 +160,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -171,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -180,30 +189,30 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -214,8 +223,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -234,8 +243,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -246,8 +255,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handshake_ix.go b/handshake_ix.go index d53a6a8..ba5c777 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,13 +1,14 @@ package nebula import ( + "net/netip" + "slices" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // NOISE IX Handshakes @@ -17,40 +18,70 @@ import ( func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return false } - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) + // If we're connecting to a v6 address we must use a v2 cert + cs := f.pki.getCertState() + v := cs.defaultVersion + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } + } + + crt := cs.getCertificate(v) + if crt == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate is available") + return false + } + + crtHs := cs.getHandshakeBytes(v) + if crtHs == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate handshake bytes is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Failed to create connection state") + return false + } hh.hostinfo.ConnectionState = ci - hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: hh.hostinfo.localIndexId, - Time: uint64(time.Now().UnixNano()), - Cert: certState.RawCertificateNoKey, - } - - hsBytes := []byte{} - hs := &NebulaHandshake{ - Details: hsProto, + Details: &NebulaHandshakeDetails{ + InitiatorIndex: hh.hostinfo.localIndexId, + Time: uint64(time.Now().UnixNano()), + Cert: crtHs, + CertVersion: uint32(v), + }, } - hsBytes, err = hs.Marshal() + hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("certVersion", v). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) - ci.messageCounter.Add(1) msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return false } @@ -64,67 +95,147 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return true } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) +func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { + cs := f.pki.getCertState() + crt := cs.GetDefaultCertificate() + if crt == nil { + f.l.WithField("udpAddr", addr). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", cs.defaultVersion). + Error("Unable to handshake with host because no certificate is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to create connection state") + return + } + // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to call noise.ReadMessage") return } hs := &NebulaHandshake{} err = hs.Unmarshal(msg) - /* - l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) - */ if err != nil || hs.Details == nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed unmarshal handshake message") return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - e := f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Handshake did not contain a certificate") + return + } - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) + remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) + if err != nil { + fp, err := rc.Fingerprint() + if err != nil { + fp = "" + } + + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("certVpnNetworks", rc.Networks()). + WithField("certFingerprint", fp) + + if f.l.Level >= logrus.DebugLevel { + e = e.WithField("cert", rc) } e.Info("Invalid certificate from host") return } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - issuer := remoteCert.Details.Issuer - if vpnIp == f.myVpnIp { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + if remoteCert.Certificate.Version() != ci.myCert.Version() { + // We started off using the wrong certificate version, lets see if we can match the version that was sent to us + rc := cs.getCertificate(remoteCert.Certificate.Version()) + if rc == nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). + Info("Unable to handshake with host due to missing certificate version") + return + } + + // Record the certificate we are actually using + ci.myCert = rc + } + + if len(remoteCert.Certificate.Networks()) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("cert", remoteCert). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("No networks in certificate") return } - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + var vpnAddrs []netip.Addr + var filteredNetworks []netip.Prefix + certName := remoteCert.Certificate.Name() + certVersion := remoteCert.Certificate.Version() + fingerprint := remoteCert.Fingerprint + issuer := remoteCert.Certificate.Issuer() + + for _, network := range remoteCert.Certificate.Networks() { + vpnAddr := network.Addr() + _, found := f.myVpnAddrsTable.Lookup(vpnAddr) + if found { + f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("certVersion", certVersion). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + return + } + + // vpnAddrs outside our vpn networks are of no use to us, filter them out + if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + continue + } + + filteredNetworks = append(filteredNetworks, network) + vpnAddrs = append(vpnAddrs, vpnAddr) + } + + if len(vpnAddrs) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("certVersion", certVersion). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") + return + } + + if addr.IsValid() { + // addr can be invalid when the tunnel is being relayed. + // We only want to apply the remote allow list for direct tunnels here + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") @@ -136,19 +247,20 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by ConnectionState: ci, localIndexId: myIndex, remoteIndexId: hs.Details.InitiatorIndex, - vpnIp: vpnIp, + vpnAddrs: vpnAddrs, HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - syncRWMutex: newSyncRWMutex("relay-state"), - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + syncRWMutex: newSyncRWMutex("relay-state"), + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -156,14 +268,29 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = certState.RawCertificateNoKey + hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) + if hs.Details.Cert == nil { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("certVersion", certVersion). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("certVersion", ci.myCert.Version()). + Error("Unable to handshake with host because no certificate handshake bytes is available") + return + } + + hs.Details.CertVersion = uint32(ci.myCert.Version()) // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") @@ -173,15 +300,17 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") @@ -204,9 +333,9 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.CreateRemoteCIDR(remoteCert) + hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -216,19 +345,19 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err := f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } @@ -238,17 +367,18 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime). WithField("fingerprint", fingerprint). @@ -258,24 +388,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by Info("Handshake too old") // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). + WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). Error("Failed to add HostInfo due to localIndex collision") return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -287,19 +419,21 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err = f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake") } else { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -311,10 +445,14 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + // I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure + // it's correctly marked as working. + via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -323,13 +461,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by } f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.ConnectionState.messageCounter.Store(2) + hostinfo.remotes.ResetBlockedRemotes() return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -339,9 +477,10 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha defer hh.Unlock() hostinfo := hh.hostinfo - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if addr.IsValid() { + // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } } @@ -349,7 +488,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -358,7 +497,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") @@ -370,82 +509,57 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Handshake did not contain a certificate") + return true + } - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) + remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) + if err != nil { + fp, err := rc.Fingerprint() + if err != nil { + fp = "" } - e.Error("Invalid certificate from host") - - // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again - return true - } - - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - issuer := remoteCert.Details.Issuer - - // Ensure the right host responded - if vpnIp != hostinfo.vpnIp { - f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). - WithField("udpAddr", addr).WithField("certName", certName). + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Incorrect host responded to handshake") + WithField("certFingerprint", fp). + WithField("certVpnNetworks", rc.Networks()) - // Release our old handshake from pending, it should not continue - f.handshakeManager.DeleteHostInfo(hostinfo) - - // Create a new hostinfo/handshake for the intended vpn ip - f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { - //TODO: this doesnt know if its being added or is being used for caching a packet - // Block the current used address - newHH.hostinfo.remotes = hostinfo.remotes - newHH.hostinfo.remotes.BlockRemote(addr) - - // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - - f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). - Info("Blocked addresses for handshakes") - - // Swap the packet store to benefit the original intended recipient - newHH.packetStore = hh.packetStore - hh.packetStore = []*cachedPacket{} - - // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnIp = vpnIp - f.sendCloseTunnel(hostinfo) - }) + if f.l.Level >= logrus.DebugLevel { + e = e.WithField("cert", rc) + } + e.Info("Invalid certificate from host") return true } - // Mark packet 2 as seen so it doesn't show up as missed - ci.window.Update(f.l, 2) + if len(remoteCert.Certificate.Networks()) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("cert", remoteCert). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("No networks in certificate") + return true + } - duration := time.Since(hh.startTime).Nanoseconds() - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("durationNs", duration). - WithField("sentCachedPackets", len(hh.packetStore)). - Info("Handshake message received") + vpnNetworks := remoteCert.Certificate.Networks() + certName := remoteCert.Certificate.Name() + certVersion := remoteCert.Certificate.Version() + fingerprint := remoteCert.Fingerprint + issuer := remoteCert.Certificate.Issuer() hostinfo.remoteIndexId = hs.Details.ResponderIndex hostinfo.lastHandshakeTime = hs.Details.Time @@ -456,21 +570,93 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha ci.eKey = NewNebulaCipherState(eKey) // Make sure the current udpAddr being used is set for responding - if addr != nil { + if addr.IsValid() { hostinfo.SetRemote(addr) } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } - // Build up the radix for the firewall if we have subnets in the cert - hostinfo.CreateRemoteCIDR(remoteCert) + var vpnAddrs []netip.Addr + var filteredNetworks []netip.Prefix + for _, network := range vpnNetworks { + // vpnAddrs outside our vpn networks are of no use to us, filter them out + vpnAddr := network.Addr() + if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + continue + } - // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp + filteredNetworks = append(filteredNetworks, network) + vpnAddrs = append(vpnAddrs, vpnAddr) + } + + if len(vpnAddrs) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("certVersion", certVersion). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") + return true + } + + // Ensure the right host responded + if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { + f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). + WithField("udpAddr", addr). + WithField("certName", certName). + WithField("certVersion", certVersion). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Incorrect host responded to handshake") + + // Release our old handshake from pending, it should not continue + f.handshakeManager.DeleteHostInfo(hostinfo) + + // Create a new hostinfo/handshake for the intended vpn ip + f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { + // Block the current used address + newHH.hostinfo.remotes = hostinfo.remotes + newHH.hostinfo.remotes.BlockRemote(addr) + + f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). + WithField("vpnNetworks", vpnNetworks). + WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). + Info("Blocked addresses for handshakes") + + // Swap the packet store to benefit the original intended recipient + newHH.packetStore = hh.packetStore + hh.packetStore = []*cachedPacket{} + + // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down + hostinfo.vpnAddrs = vpnAddrs + f.sendCloseTunnel(hostinfo) + }) + + return true + } + + // Mark packet 2 as seen so it doesn't show up as missed + ci.window.Update(f.l, 2) + + duration := time.Since(hh.startTime).Nanoseconds() + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("certVersion", certVersion). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + WithField("durationNs", duration). + WithField("sentCachedPackets", len(hh.packetStore)). + Info("Handshake message received") + + // Build up the radix for the firewall if we have subnets in the cert + hostinfo.vpnAddrs = vpnAddrs + hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + + // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.ConnectionState.messageCounter.Store(2) - if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) } diff --git a/handshake_manager.go b/handshake_manager.go index 7222919..ee5b312 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -6,13 +6,14 @@ import ( "crypto/rand" "encoding/binary" "errors" - "net" + "net/netip" + "slices" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" ) @@ -34,7 +35,7 @@ var ( type HandshakeConfig struct { tryInterval time.Duration - retries int + retries int64 triggerBuffer int useRelays bool @@ -45,14 +46,14 @@ type HandshakeManager struct { // Mutex for interacting with the vpnIps and indexes maps syncRWMutex - vpnIps map[iputil.VpnIp]*HandshakeHostInfo + vpnIps map[netip.Addr]*HandshakeHostInfo indexes map[uint32]*HandshakeHostInfo mainHostMap *HostMap lightHouse *LightHouse outside udp.Conn config HandshakeConfig - OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] + OutboundHandshakeTimer *LockingTimerWheel[netip.Addr] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter @@ -60,17 +61,17 @@ type HandshakeManager struct { l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIp - trigger chan iputil.VpnIp + trigger chan netip.Addr } type HandshakeHostInfo struct { syncMutex - startTime time.Time // Time that we first started trying with this handshake - ready bool // Is the handshake ready - counter int // How many attempts have we made so far - lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt - packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + counter int64 // How many attempts have we made so far + lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo } @@ -103,14 +104,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ syncRWMutex: newSyncRWMutex("handshake-manager"), - vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, + vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, config: config, - trigger: make(chan iputil.VpnIp, config.triggerBuffer), - OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp]("handshake-manager-timer", config.tryInterval, hsTimeout(config.retries, config.tryInterval)), + trigger: make(chan netip.Addr, config.triggerBuffer), + OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr]("handshake-manager-timer", config.tryInterval, hsTimeout(config.retries, config.tryInterval)), messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -118,26 +119,26 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig } } -func (c *HandshakeManager) Run(ctx context.Context) { - clockSource := time.NewTicker(c.config.tryInterval) +func (hm *HandshakeManager) Run(ctx context.Context) { + clockSource := time.NewTicker(hm.config.tryInterval) defer clockSource.Stop() for { select { case <-ctx.Done(): return - case vpnIP := <-c.trigger: - c.handleOutbound(vpnIP, true) + case vpnIP := <-hm.trigger: + hm.handleOutbound(vpnIP, true) case now := <-clockSource.C: - c.NextOutboundHandshakeTimerTick(now) + hm.NextOutboundHandshakeTimerTick(now) } } } -func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp - if addr != nil { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { + if addr.IsValid() { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) { hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -159,18 +160,18 @@ func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packe } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { - c.OutboundHandshakeTimer.Advance(now) +func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { + hm.OutboundHandshakeTimer.Advance(now) for { - vpnIp, has := c.OutboundHandshakeTimer.Purge() + vpnIp, has := hm.OutboundHandshakeTimer.Purge() if !has { break } - c.handleOutbound(vpnIp, false) + hm.handleOutbound(vpnIp, false) } } -func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { +func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) { hh := hm.queryVpnIp(vpnIp) if hh == nil { return @@ -208,11 +209,11 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp}) } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) - remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) + remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. // This is a very specific optimization for a fast lighthouse reply. @@ -223,7 +224,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hh.lastRemotes = remotes - // TODO: this will generate a load of queries for hosts with only 1 ip + // This will generate a load of queries for hosts with only 1 ip // (such as ones registered to the lighthouse with only a private IP) // So we only do it one time after attempting 5 handshakes already. if len(remotes) <= 1 && hh.counter == 5 { @@ -234,8 +235,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply - var sentTo []*udp.Addr - hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) { + var sentTo []netip.AddrPort + hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { @@ -256,7 +257,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake message sent") - } else if hm.l.IsLevelEnabled(logrus.DebugLevel) { + } else if hm.l.Level >= logrus.DebugLevel { hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). @@ -267,56 +268,28 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { - // Don't relay to myself, and don't relay through the host I'm trying to connect to - if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { + // Don't relay to myself + if relay == vpnIp { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) - if relayHostInfo == nil || relayHostInfo.remote == nil { + + // Don't relay through the host I'm trying to connect to + _, found := hm.f.myVpnAddrsTable.Lookup(relay) + if found { + continue + } + + relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) + if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - hm.f.Handshake(*relay) + hm.f.Handshake(relay) continue } - // Check the relay HostInfo to see if we already established a relay through it - if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { - switch existingRelay.State { - case Established: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) - case Requested: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") - // Re-send the CreateRelay request, in case the previous one was lost. - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), - } - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l). - WithError(err). - Error("Failed to marshal Control message to create relay") - } else { - // This must send over the hostinfo, not over hm.Hosts[ip] - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, - "relayTo": vpnIp, - "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": *relay}). - Info("send CreateRelayRequest") - } - default: - hostinfo.logger(hm.l). - WithField("vpnIp", vpnIp). - WithField("state", existingRelay.State). - WithField("relay", relayHostInfo.vpnIp). - Errorf("Relay unexpected state") - } - } else { + // Check the relay HostInfo to see if we already established a relay through + existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) + if !ok { // No relays exist or requested yet. - if relayHostInfo.remote != nil { + if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") @@ -325,9 +298,32 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() if err != nil { hostinfo.logger(hm.l). @@ -336,13 +332,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnAddrs[0], "relayTo": vpnIp, "initiatorRelayIndex": idx, - "relay": *relay}). + "relay": relay}). Info("send CreateRelayRequest") } } + continue + } + + switch existingRelay.State { + case Established: + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + case Disestablished: + // Mark this relay as 'requested' + relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) + fallthrough + case Requested: + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + // Re-send the CreateRelay request, in case the previous one was lost. + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: existingRelay.LocalIndex, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(hm.l). + WithError(err). + Error("Failed to marshal Control message to create relay") + } else { + // This must send over the hostinfo, not over hm.Hosts[ip] + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.f.myVpnAddrs[0], + "relayTo": vpnIp, + "initiatorRelayIndex": existingRelay.LocalIndex, + "relay": relay}). + Info("send CreateRelayRequest") + } + case PeerRequested: + // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. + fallthrough + default: + hostinfo.logger(hm.l). + WithField("vpnIp", vpnIp). + WithField("state", existingRelay.State). + WithField("relay", relay). + Errorf("Relay unexpected state") + } } } @@ -355,11 +418,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic -func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { - // Check the main hostmap and maintain a read lock if our host is not there +func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { hm.mainHostMap.RLock() - if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { - hm.mainHostMap.RUnlock() + h, ok := hm.mainHostMap.Hosts[vpnIp] + hm.mainHostMap.RUnlock() + + if ok { // Do not attempt promotion if you are a lighthouse if !hm.lightHouse.amLighthouse { h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f) @@ -367,15 +431,14 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han return h, true } - defer hm.mainHostMap.RUnlock() return hm.StartHandshake(vpnIp, cacheCb), false } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() - if hh, ok := hm.vpnIps[vpnIp]; ok { + if hh, ok := hm.vpnIps[vpnAddr]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { cacheCb(hh) @@ -386,13 +449,13 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han hostinfo := &HostInfo{ syncRWMutex: newSyncRWMutex("hostinfo"), - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - syncRWMutex: newSyncRWMutex("relay-state"), - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + syncRWMutex: newSyncRWMutex("relay-state"), + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } @@ -401,9 +464,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han hostinfo: hostinfo, startTime: time.Now(), } - hm.vpnIps[vpnIp] = hh + hm.vpnIps[vpnAddr] = hh hm.metricInitiated.Inc(1) - hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) + hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval) if cacheCb != nil { cacheCb(hh) @@ -411,21 +474,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now - _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] + _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr] if !doTrigger { // Add any calculated remotes, and trigger early handshake if one found - doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) + doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr) } if doTrigger { select { - case hm.trigger <- vpnIp: + case hm.trigger <- vpnAddr: default: } } hm.Unlock() - hm.lightHouse.QueryServer(vpnIp) + hm.lightHouse.QueryServer(vpnAddr) return hostinfo } @@ -446,14 +509,14 @@ var ( // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. -func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { - c.mainHostMap.Lock() - defer c.mainHostMap.Unlock() - c.Lock() - defer c.Unlock() +func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { + hm.mainHostMap.Lock() + defer hm.mainHostMap.Unlock() + hm.Lock() + defer hm.Unlock() // Check if we already have a tunnel with this vpn ip - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] + existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]] if found && existingHostInfo != nil { testHostInfo := existingHostInfo for testHostInfo != nil { @@ -470,31 +533,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingHostInfo, ErrExistingHostInfo } - existingHostInfo.logger(c.l).Info("Taking new handshake") + existingHostInfo.logger(hm.l).Info("Taking new handshake") } - existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] + existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId] if found { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } - existingPendingIndex, found := c.indexes[hostinfo.localIndexId] + existingPendingIndex, found := hm.indexes[hostinfo.localIndexId] if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo - return existingIndex, ErrLocalIndexCollision + return existingPendingIndex.hostinfo, ErrLocalIndexCollision } - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] - if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp { + existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(c.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + hostinfo.logger(hm.l). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } - c.mainHostMap.unlockedAddHostInfo(hostinfo, f) + hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) return existingHostInfo, nil } @@ -512,7 +575,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } @@ -549,31 +612,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { return errors.New("failed to generate unique localIndexId") } -func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { - c.Lock() - defer c.Unlock() - c.unlockedDeleteHostInfo(hostinfo) +func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { + hm.Lock() + defer hm.Unlock() + hm.unlockedDeleteHostInfo(hostinfo) } -func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { - delete(c.vpnIps, hostinfo.vpnIp) - if len(c.vpnIps) == 0 { - c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} +func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { + for _, addr := range hostinfo.vpnAddrs { + delete(hm.vpnIps, addr) } - delete(c.indexes, hostinfo.localIndexId) - if len(c.vpnIps) == 0 { - c.indexes = map[uint32]*HandshakeHostInfo{} + if len(hm.vpnIps) == 0 { + hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } - if c.l.Level >= logrus.DebugLevel { - c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + delete(hm.indexes, hostinfo.localIndexId) + if len(hm.indexes) == 0 { + hm.indexes = map[uint32]*HandshakeHostInfo{} + } + + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Pending hostmap hostInfo deleted") } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { +func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -582,7 +648,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { } -func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { +func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo { hm.RLock() defer hm.RUnlock() return hm.vpnIps[vpnIp] @@ -602,37 +668,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { - return c.mainHostMap.GetPreferredRanges() +func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix { + return hm.mainHostMap.GetPreferredRanges() } -func (c *HandshakeManager) ForEachVpnIp(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.vpnIps { + for _, v := range hm.vpnIps { f(v.hostinfo) } } -func (c *HandshakeManager) ForEachIndex(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachIndex(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.indexes { + for _, v := range hm.indexes { f(v.hostinfo) } } -func (c *HandshakeManager) EmitStats() { - c.RLock() - hostLen := len(c.vpnIps) - indexLen := len(c.indexes) - c.RUnlock() +func (hm *HandshakeManager) EmitStats() { + hm.RLock() + hostLen := len(hm.vpnIps) + indexLen := len(hm.indexes) + hm.RUnlock() metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) - c.mainHostMap.EmitStats() + hm.mainHostMap.EmitStats() } // Utility functions below @@ -659,6 +725,6 @@ func generateIndex(l *logrus.Logger) (uint32, error) { return index, nil } -func hsTimeout(tries int, interval time.Duration) time.Duration { - return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval))) +func hsTimeout(tries int64, interval time.Duration) time.Duration { + return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval))) } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 9a63357..4b898af 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -1,13 +1,12 @@ package nebula import ( - "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -15,20 +14,20 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} - mainHM := newHostMap(l, vpncidr) + localrange := netip.MustParsePrefix("10.1.1.1/24") + ip := netip.MustParseAddr("172.1.1.2") + + preferredRanges := []netip.Prefix{localrange} + mainHM := newHostMap(l) mainHM.preferredRanges.Store(&preferredRanges) lh := newTestLighthouse() cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -42,10 +41,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) - i.remotes = NewRemoteList(nil) + i.remotes = NewRemoteList([]netip.Addr{}, nil) // Adding something to pending should not affect the main hostmap - assert.Len(t, mainHM.Hosts, 0) + assert.Empty(t, mainHM.Hosts) // Confirm they are in the pending index list assert.Contains(t, blah.vpnIps, ip) @@ -66,7 +65,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.NotContains(t, blah.vpnIps, ip) } -func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { +func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { for _, i := range tw.t.wheel { n := i.Head for n != nil { @@ -80,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { return } -func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { +func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { return } -func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { return } -func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {} +func (mw *mockEncWriter) Handshake(_ netip.Addr) {} + +func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { + return nil +} + +func (mw *mockEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: cert.Version2} +} diff --git a/header/header.go b/header/header.go index 50b7d62..f22509b 100644 --- a/header/header.go +++ b/header/header.go @@ -19,7 +19,7 @@ import ( // |-----------------------------------------------------------------------| // | payload... | -type m map[string]interface{} +type m = map[string]any const ( Version uint8 = 1 diff --git a/header/header_test.go b/header/header_test.go index 765a006..a7e5374 100644 --- a/header/header_test.go +++ b/header/header_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type headerTest struct { @@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) { func TestHeader_MarshalJSON(t *testing.T) { b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() - assert.Nil(t, err) + require.NoError(t, err) assert.Equal( t, "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", diff --git a/hostmap.go b/hostmap.go index 362a76a..7c1fe6a 100644 --- a/hostmap.go +++ b/hostmap.go @@ -3,17 +3,16 @@ package nebula import ( "errors" "net" + "net/netip" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // const ProbeLen = 100 @@ -35,6 +34,7 @@ const ( Requested = iota PeerRequested Established + Disestablished ) const ( @@ -48,7 +48,7 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp iputil.VpnIp + PeerAddr netip.Addr } type HostMap struct { @@ -56,9 +56,8 @@ type HostMap struct { Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo - Hosts map[iputil.VpnIp]*HostInfo - preferredRanges atomic.Pointer[[]*net.IPNet] - vpnCIDR *net.IPNet + Hosts map[netip.Addr]*HostInfo + preferredRanges atomic.Pointer[[]netip.Prefix] l *logrus.Logger } @@ -68,17 +67,42 @@ type HostMap struct { type RelayState struct { syncRWMutex - relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer + // For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data, + // modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with + // the RelayState Lock held) + relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } -func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) { +func (rs *RelayState) DeleteRelay(ip netip.Addr) { rs.Lock() defer rs.Unlock() delete(rs.relays, ip) } +func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) { + rs.Lock() + defer rs.Unlock() + if r, ok := rs.relayForByAddr[vpnIp]; ok { + newRelay := *r + newRelay.State = state + rs.relayForByAddr[newRelay.PeerAddr] = &newRelay + rs.relayForByIdx[newRelay.LocalIndex] = &newRelay + } +} + +func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) { + rs.Lock() + defer rs.Unlock() + if r, ok := rs.relayForByIdx[idx]; ok { + newRelay := *r + newRelay.State = state + rs.relayForByAddr[newRelay.PeerAddr] = &newRelay + rs.relayForByIdx[newRelay.LocalIndex] = &newRelay + } +} + func (rs *RelayState) CopyAllRelayFor() []*Relay { rs.RLock() defer rs.RUnlock() @@ -89,34 +113,34 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[ip] + r, ok := rs.relayForByAddr[addr] return r, ok } -func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) { +func (rs *RelayState) InsertRelayTo(ip netip.Addr) { rs.Lock() defer rs.Unlock() rs.relays[ip] = struct{}{} } -func (rs *RelayState) CopyRelayIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - ret := make([]iputil.VpnIp, 0, len(rs.relays)) + ret := make([]netip.Addr, 0, len(rs.relays)) for ip := range rs.relays { ret = append(ret, ip) } return ret } -func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp)) - for relayIp := range rs.relayForByIp { + currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr)) + for relayIp := range rs.relayForByAddr { currentRelays = append(currentRelays, relayIp) } return currentRelays @@ -132,22 +156,10 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { return ret } -func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) { +func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() - r, ok := rs.relayForByIdx[localIdx] - if !ok { - return iputil.VpnIp(0), false - } - delete(rs.relayForByIdx, localIdx) - delete(rs.relayForByIp, r.PeerIp) - return r.PeerIp, true -} - -func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool { - rs.Lock() - defer rs.Unlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] if !ok { return false } @@ -155,7 +167,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bo newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return true } @@ -170,14 +182,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return &newRelay, true } -func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] return r, ok } @@ -188,25 +200,31 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { return r, ok } -func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { +func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() - rs.relayForByIp[ip] = r + rs.relayForByAddr[ip] = r rs.relayForByIdx[idx] = r } type HostInfo struct { syncRWMutex - remote *udp.Addr + remote netip.AddrPort remotes *RemoteList promoteCounter atomic.Uint32 ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp iputil.VpnIp - recvError atomic.Uint32 - remoteCidr *cidr.Tree4[struct{}] - relayState RelayState + + // vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks + // The host may have other vpn addresses that are outside our + // vpn networks but were removed because they are not usable + vpnAddrs []netip.Addr + recvError atomic.Uint32 + + // networks are both all vpn and unsafe networks assigned to this host + networks *bart.Table[struct{}] + relayState RelayState // HandshakePacket records the packets used to create this hostinfo // We need these to avoid replayed handshake packets creating new hostinfos which causes churn @@ -227,7 +245,7 @@ type HostInfo struct { lastHandshakeTime uint64 lastRoam time.Time - lastRoamRemote *udp.Addr + lastRoamRemote netip.AddrPort // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. @@ -254,40 +272,38 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap { - hm := newHostMap(l, vpnCIDR) +func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { + hm := newHostMap(l) hm.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { hm.reload(c, false) }) - l.WithField("network", hm.vpnCIDR.String()). - WithField("preferredRanges", hm.GetPreferredRanges()). + l.WithField("preferredRanges", hm.GetPreferredRanges()). Info("Main HostMap created") return hm } -func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { +func newHostMap(l *logrus.Logger) *HostMap { return &HostMap{ syncRWMutex: newSyncRWMutex("hostmap"), Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, - Hosts: map[iputil.VpnIp]*HostInfo{}, - vpnCIDR: vpnCIDR, + Hosts: map[netip.Addr]*HostInfo{}, l: l, } } func (hm *HostMap) reload(c *config.C, initial bool) { if initial || c.HasChanged("preferred_ranges") { - var preferredRanges []*net.IPNet + var preferredRanges []netip.Prefix rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) for _, rawPreferredRange := range rawPreferredRanges { - _, preferredRange, err := net.ParseCIDR(rawPreferredRange) + preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") @@ -319,17 +335,6 @@ func (hm *HostMap) EmitStats() { metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } -func (hm *HostMap) RemoveRelay(localIdx uint32) { - hm.Lock() - _, ok := hm.Relays[localIdx] - if !ok { - hm.Unlock() - return - } - delete(hm.Relays, localIdx) - hm.Unlock() -} - // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore @@ -349,48 +354,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { } func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { - oldHostinfo := hm.Hosts[hostinfo.vpnIp] + // Get the current primary, if it exists + oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]] + + // Every address in the hostinfo gets elevated to primary + for _, vpnAddr := range hostinfo.vpnAddrs { + //NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on + // indexes so it should be fine. + hm.Hosts[vpnAddr] = hostinfo + } + + // If we are already primary then we won't bother re-linking if oldHostinfo == hostinfo { return } + // Unlink this hostinfo if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } - if hostinfo.next != nil { hostinfo.next.prev = hostinfo.prev } - hm.Hosts[hostinfo.vpnIp] = hostinfo - + // If there wasn't a previous primary then clear out any links if oldHostinfo == nil { + hostinfo.next = nil + hostinfo.prev = nil return } + // Relink the hostinfo as primary hostinfo.next = oldHostinfo oldHostinfo.prev = hostinfo hostinfo.prev = nil } func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { - primary, ok := hm.Hosts[hostinfo.vpnIp] + for _, addr := range hostinfo.vpnAddrs { + h := hm.Hosts[addr] + for h != nil { + if h == hostinfo { + hm.unlockedInnerDeleteHostInfo(h, addr) + } + h = h.next + } + } +} + +func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) { + primary, ok := hm.Hosts[addr] + isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil if ok && primary == hostinfo { - // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it - delete(hm.Hosts, hostinfo.vpnIp) + // The vpn addr pointer points to the same hostinfo as the local index id, we can remove it + delete(hm.Hosts, addr) if len(hm.Hosts) == 0 { - hm.Hosts = map[iputil.VpnIp]*HostInfo{} + hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { - // We had more than 1 hostinfo at this vpnip, promote the next in the list to primary - hm.Hosts[hostinfo.vpnIp] = hostinfo.next + // We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary + hm.Hosts[addr] = hostinfo.next // It is primary, there is no previous hostinfo now hostinfo.next.prev = nil } } else { - // Relink if we were in the middle of multiple hostinfos for this vpn ip + // Relink if we were in the middle of multiple hostinfos for this vpn addr if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } @@ -420,10 +450,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } + if isLastHostinfo { + // I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next + // hops as 'Requested' so that new relay tunnels are created in the future. + hm.unlockedDisestablishVpnAddrRelayFor(hostinfo) + } + // Clean up any local relay indexes for which I am acting as a relay hop for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { delete(hm.Relays, localRelayIdx) } @@ -462,11 +498,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { - return hm.queryVpnIp(vpnIp, nil) +func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { + return hm.queryVpnAddr(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -474,17 +510,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host if !ok { return nil, nil, errors.New("unable to find host") } + for h != nil { - r, ok := h.relayState.QueryRelayForByIp(targetIp) - if ok && r.State == Established { - return h, r, nil + for _, targetIp := range targetIps { + r, ok := h.relayState.QueryRelayForByIp(targetIp) + if ok && r.State == Established { + return h, r, nil + } } h = h.next } + return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) { + for _, relayHostIp := range hi.relayState.CopyRelayIps() { + if h, ok := hm.Hosts[relayHostIp]; ok { + for h != nil { + h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) + h = h.next + } + } + } + for _, rs := range hi.relayState.CopyAllRelayFor() { + if rs.Type == ForwardingType { + if h, ok := hm.Hosts[rs.PeerAddr]; ok { + for h != nil { + h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) + h = h.next + } + } + } + } +} + +func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -505,25 +566,30 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) + dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } - - existing := hm.Hosts[hostinfo.vpnIp] - hm.Hosts[hostinfo.vpnIp] = hostinfo - - if existing != nil { - hostinfo.next = existing - existing.prev = hostinfo + for _, addr := range hostinfo.vpnAddrs { + hm.unlockedInnerAddHostInfo(addr, hostinfo, f) } hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). + hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). Debug("Hostmap vpnIp added") } +} + +func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { + existing := hm.Hosts[vpnAddr] + hm.Hosts[vpnAddr] = hostinfo + + if existing != nil && existing != hostinfo { + hostinfo.next = existing + existing.prev = hostinfo + } i := 1 check := hostinfo @@ -536,12 +602,12 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } -func (hm *HostMap) GetPreferredRanges() []*net.IPNet { +func (hm *HostMap) GetPreferredRanges() []netip.Prefix { //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer return *hm.preferredRanges.Load() } -func (hm *HostMap) ForEachVpnIp(f controlEach) { +func (hm *HostMap) ForEachVpnAddr(f controlEach) { hm.RLock() defer hm.RUnlock() @@ -561,14 +627,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) { // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! -func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { +func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) { c := i.promoteCounter.Add(1) if c%ifce.tryPromoteEvery.Load() == 0 { remote := i.remote // return early if we are already on a preferred remote - if remote != nil { - rIP := remote.IP + if remote.IsValid() { + rIP := remote.Addr() for _, l := range preferredRanges { if l.Contains(rIP) { return @@ -576,8 +642,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } } - i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) { - if remote != nil && (addr == nil || !preferred) { + i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { + if remote.IsValid() && (!addr.IsValid() || !preferred) { return } @@ -595,34 +661,34 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) - ifce.lightHouse.QueryServer(i.vpnIp) + ifce.lightHouse.QueryServer(i.vpnAddrs[0]) } } -func (i *HostInfo) GetCert() *cert.NebulaCertificate { +func (i *HostInfo) GetCert() *cert.CachedCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert } return nil } -func (i *HostInfo) SetRemote(remote *udp.Addr) { +func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object - if !i.remote.Equals(remote) { - i.remote = remote.Copy() - i.remotes.LearnRemote(i.vpnIp, remote.Copy()) + if i.remote != remote { + i.remote = remote + i.remotes.LearnRemote(i.vpnAddrs[0], remote) } } // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. -func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { - if newRemote == nil { +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { + if !newRemote.IsValid() { // relays have nil udp Addrs return false } currentRemote := i.remote - if currentRemote == nil { + if !currentRemote.IsValid() { i.SetRemote(newRemote) return true } @@ -632,11 +698,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { newIsPreferred := false for _, l := range hm.GetPreferredRanges() { // return early if we are already on a preferred remote - if l.Contains(currentRemote.IP) { + if l.Contains(currentRemote.Addr()) { return false } - if l.Contains(newRemote.IP) { + if l.Contains(newRemote.Addr()) { newIsPreferred = true } } @@ -644,7 +710,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { if newIsPreferred { // Consider this a roaming event i.lastRoam = time.Now() - i.lastRoamRemote = currentRemote.Copy() + i.lastRoamRemote = currentRemote i.SetRemote(newRemote) @@ -661,21 +727,20 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } -func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { - if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 { +func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { + if len(networks) == 1 && len(unsafeNetworks) == 0 { // Simple case, no CIDRTree needed return } - remoteCidr := cidr.NewTree4[struct{}]() - for _, ip := range c.Details.Ips { - remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + i.networks = new(bart.Table[struct{}]) + for _, network := range networks { + i.networks.Insert(network, struct{}{}) } - for _, n := range c.Details.Subnets { - remoteCidr.AddCIDR(n, struct{}{}) + for _, network := range unsafeNetworks { + i.networks.Insert(network, struct{}{}) } - i.remoteCidr = remoteCidr } func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { @@ -683,13 +748,13 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { return logrus.NewEntry(l) } - li := l.WithField("vpnIp", i.vpnIp). + li := l.WithField("vpnAddrs", i.vpnAddrs). WithField("localIndex", i.localIndexId). WithField("remoteIndex", i.remoteIndexId) if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Details.Name) + li = li.WithField("certName", peerCert.Certificate.Name()) } } @@ -698,9 +763,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { +func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage - var ips []net.IP + var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) @@ -712,30 +777,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { continue } addrs, _ := i.Addrs() - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { + for _, rawAddr := range addrs { + var addr netip.Addr + switch v := rawAddr.(type) { case *net.IPNet: //continue - ip = v.IP + addr, _ = netip.AddrFromSlice(v.IP) case *net.IPAddr: - ip = v.IP + addr, _ = netip.AddrFromSlice(v.IP) } - //TODO: Filtering out link local for now, this is probably the most correct thing - //TODO: Would be nice to filter out SLAAC MAC based ips as well - if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() { - allow := allowList.Allow(ip) - if l.Level >= logrus.TraceLevel { - l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow") + if !addr.IsValid() { + if l.Level >= logrus.DebugLevel { + l.WithField("localAddr", rawAddr).Debug("addr was invalid") } - if !allow { + continue + } + addr = addr.Unmap() + + if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { + isAllowed := allowList.Allow(addr) + if l.Level >= logrus.TraceLevel { + l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") + } + if !isAllowed { continue } - ips = append(ips, ip) + finalAddrs = append(finalAddrs, addr) } } } - return &ips + return finalAddrs } diff --git a/hostmap_test.go b/hostmap_test.go index 8311cef..b3580cf 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -1,7 +1,7 @@ package nebula import ( - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" @@ -11,20 +11,14 @@ import ( func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) @@ -32,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -47,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -62,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -77,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -91,22 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} - h5 := &HostInfo{vpnIp: 1, localIndexId: 5} - h6 := &HostInfo{vpnIp: 1, localIndexId: 6} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} + h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5} + h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -122,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -141,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -159,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -175,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -189,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -201,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } @@ -209,16 +197,9 @@ func TestHostMap_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - hm := NewHostMapFromConfig( - l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - c, - ) + hm := NewHostMapFromConfig(l, c) - toS := func(ipn []*net.IPNet) []string { + toS := func(ipn []netip.Prefix) []string { var s []string for _, n := range ipn { s = append(s, n.String()) @@ -229,8 +210,8 @@ func TestHostMap_reload(t *testing.T) { assert.Empty(t, hm.GetPreferredRanges()) c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]") - assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) + assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") - assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) + assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) } diff --git a/hostmap_tester.go b/hostmap_tester.go index 0d5d41b..fe40c53 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -5,10 +5,12 @@ package nebula // This file contains functions used to export information to the e2e testing framework -import "github.com/slackhq/nebula/iputil" +import ( + "net/netip" +) -func (i *HostInfo) GetVpnIp() iputil.VpnIp { - return i.vpnIp +func (i *HostInfo) GetVpnAddrs() []netip.Addr { + return i.vpnAddrs } func (i *HostInfo) GetLocalIndex() uint32 { diff --git a/inside.go b/inside.go index 079e4dd..0af350d 100644 --- a/inside.go +++ b/inside.go @@ -1,12 +1,14 @@ package nebula import ( + "net/netip" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" - "github.com/slackhq/nebula/udp" + "github.com/slackhq/nebula/routing" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { @@ -19,14 +21,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast { - return + if f.dropLocalBroadcast { + _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr) + if found { + return + } } - if fwPacket.RemoteIP == f.myVpnIp { + _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr) + if found { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which - // routes packets from the Nebula IP to the Nebula IP through the Nebula + // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) @@ -35,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } // Otherwise, drop. On linux, we should never see these packets - Linux - // routes packets from the nebula IP to the nebula IP through the loopback device. + // routes packets from the nebula addr to the nebula addr through the loopback device. return } - // Ignore broadcast packets - if f.dropMulticast && isMulticast(fwPacket.RemoteIP) { + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", fwPacket.RemoteIP). + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). - Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") } return } @@ -64,7 +70,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { f.rejectInside(packet, out, q) @@ -113,24 +119,97 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * return } - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } -func (f *Interface) Handshake(vpnIp iputil.VpnIp) { - f.getOrHandshake(vpnIp, nil) +// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established +func (f *Interface) Handshake(vpnAddr netip.Addr) { + f.getOrHandshakeNoRouting(vpnAddr, nil) } -// getOrHandshake returns nil if the vpnIp is not routable. +// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { - vpnIp = f.inside.RouteFor(vpnIp) - if vpnIp == 0 { - return nil, false - } +func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + _, found := f.myVpnNetworksTable.Lookup(vpnAddr) + if found { + return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) + } + + return nil, false +} + +// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. +// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. +func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + + destinationAddr := fwPacket.RemoteAddr + + hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) + + // Host is inside the mesh, no routing required + if hostinfo != nil { + return hostinfo, ready + } + + gateways := f.inside.RoutesFor(destinationAddr) + + switch len(gateways) { + case 0: + return nil, false + case 1: + // Single gateway route + return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback) + default: + // Multi gateway route, perform ECMP categorization + gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways) + + if !balancingOk { + // This happens if the gateway buckets were not calculated, this _should_ never happen + f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.") + } + + var handshakeInfoForChosenGateway *HandshakeHostInfo + var hhReceiver = func(hh *HandshakeHostInfo) { + handshakeInfoForChosenGateway = hh + } + + // Store the handshakeHostInfo for later. + // If this node is not reachable we will attempt other nodes, if none are reachable we will + // cache the packet for this gateway. + if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready { + return hostinfo, true + } + + // It appears the selected gateway cannot be reached, find another gateway to fallback on. + // The current implementation breaks ECMP but that seems better than no connectivity. + // If ECMP is also required when a gateway is down then connectivity status + // for each gateway needs to be kept and the weights recalculated when they go up or down. + // This would also need to interact with unsafe_route updates through reloading the config or + // use of the use_system_route_table option + + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("destination", destinationAddr). + WithField("originalGateway", gatewayAddr). + Debugln("Calculated gateway for ECMP not available, attempting other gateways") + } + + for i := range gateways { + // Skip the gateway that failed previously + if gateways[i].Addr() == gatewayAddr { + continue + } + + // We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway + if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready { + return hostinfo, true + } + } + + // No gateways reachable, cache the packet in the originally chosen gateway + cacheCallback(handshakeInfoForChosenGateway) + return hostinfo, false } - return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -152,19 +231,19 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { +// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr +func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { + hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", vpnIp). - Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") + f.l.WithField("vpnAddr", vpnAddr). + Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") } return } @@ -182,10 +261,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) { +func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } @@ -255,12 +334,11 @@ func (f *Interface) SendVia(via *HostInfo, f.connectionManager.RelayUsed(relay.LocalIndex) } -func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { +func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { if ci.eKey == nil { - //TODO: log warning return } - useRelay := remote == nil && hostinfo.remote == nil + useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() fullOut := out if useRelay { @@ -284,14 +362,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType f.connectionManager.Out(hostinfo.localIndexId) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against - // all our IPs and enable a faster roaming. + // all our addrs and enable a faster roaming. if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. - f.lightHouse.QueryServer(hostinfo.vpnIp) + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") } } @@ -308,13 +386,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType return } - if remote != nil { + if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { hostinfo.logger(f.l).WithError(err). @@ -323,7 +401,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } else { // Try to send via a relay for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") @@ -334,8 +412,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } } - -func isMulticast(ip iputil.VpnIp) bool { - // Class D multicast - return (((ip >> 24) & 0xff) & 0xf0) == 0xe0 -} diff --git a/interface.go b/interface.go index d16348a..21e198c 100644 --- a/interface.go +++ b/interface.go @@ -5,18 +5,18 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "os" "runtime" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -28,7 +28,6 @@ type InterfaceConfig struct { Outside udp.Conn Inside overlay.Device pki *PKI - Cipher string Firewall *Firewall ServeDns bool HandshakeManager *HandshakeManager @@ -52,25 +51,27 @@ type InterfaceConfig struct { } type Interface struct { - hostMap *HostMap - outside udp.Conn - inside overlay.Device - pki *PKI - cipher string - firewall *Firewall - connectionManager *connectionManager - handshakeManager *HandshakeManager - serveDns bool - createTime time.Time - lightHouse *LightHouse - localBroadcast iputil.VpnIp - myVpnIp iputil.VpnIp - dropLocalBroadcast bool - dropMulticast bool - routines int - disconnectInvalid atomic.Bool - closed atomic.Bool - relayManager *relayManager + hostMap *HostMap + outside udp.Conn + inside overlay.Device + pki *PKI + firewall *Firewall + connectionManager *connectionManager + handshakeManager *HandshakeManager + serveDns bool + createTime time.Time + lightHouse *LightHouse + myBroadcastAddrsTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate + myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate + myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate + dropLocalBroadcast bool + dropMulticast bool + routines int + disconnectInvalid atomic.Bool + closed atomic.Bool + relayManager *relayManager tryPromoteEvery atomic.Uint32 reQueryEvery atomic.Uint32 @@ -102,9 +103,11 @@ type EncWriter interface { out []byte, nocopy bool, ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) + Handshake(vpnAddr netip.Addr) + GetHostInfo(vpnAddr netip.Addr) *HostInfo + GetCertState() *CertState } type sendRecvErrorConfig uint8 @@ -115,10 +118,10 @@ const ( sendRecvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool { +func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool { switch s { case sendRecvErrorPrivate: - return ip.IsPrivate() + return endpoint.Addr().IsPrivate() case sendRecvErrorAlways: return true case sendRecvErrorNever: @@ -155,28 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no firewall rules") } - certificate := c.pki.GetCertState().Certificate - myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) + cs := c.pki.getCertState() ifce := &Interface{ - pki: c.pki, - hostMap: c.HostMap, - outside: c.Outside, - inside: c.Inside, - cipher: c.Cipher, - firewall: c.Firewall, - serveDns: c.ServeDns, - handshakeManager: c.HandshakeManager, - createTime: time.Now(), - lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), - dropLocalBroadcast: c.DropLocalBroadcast, - dropMulticast: c.DropMulticast, - routines: c.routines, - version: c.version, - writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), - myVpnIp: myVpnIp, - relayManager: c.relayManager, + pki: c.pki, + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + routines: c.routines, + version: c.version, + writers: make([]udp.Conn, c.routines), + readers: make([]io.ReadWriteCloser, c.routines), + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + myVpnAddrs: cs.myVpnAddrs, + myVpnAddrsTable: cs.myVpnAddrsTable, + myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, + relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -210,7 +214,7 @@ func (f *Interface) activate() { f.l.WithError(err).Error("Failed to get udp listen address") } - f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()). + f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). WithField("build", f.version).WithField("udpAddr", addr). WithField("boringcrypto", boringEnabled()). Info("Nebula interface is active") @@ -251,16 +255,22 @@ func (f *Interface) listenOut(i int) { runtime.LockOSThread() var li udp.Conn - // TODO clean this up with a coherent interface for each outside connection if i > 0 { li = f.writers[i] } else { li = f.outside } + ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) + plaintext := make([]byte, udp.MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + }) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -317,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -400,6 +410,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { udpStats := udp.NewUDPStatsEmitter(f.writers) certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil) + certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) for { select { @@ -409,11 +421,30 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) + + certState := f.pki.getCertState() + defaultCrt := certState.GetDefaultCertificate() + certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) + certDefaultVersion.Update(int64(defaultCrt.Version())) + + // Report the max certificate version we are capable of using + if certState.v2Cert != nil { + certMaxVersion.Update(int64(certState.v2Cert.Version())) + } else { + certMaxVersion.Update(int64(certState.v1Cert.Version())) + } } } } +func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return f.hostMap.QueryVpnAddr(vpnIp) +} + +func (f *Interface) GetCertState() *CertState { + return f.pki.getCertState() +} + func (f *Interface) Close() error { f.closed.Store(true) diff --git a/iputil/util.go b/iputil/util.go deleted file mode 100644 index 65f7677..0000000 --- a/iputil/util.go +++ /dev/null @@ -1,93 +0,0 @@ -package iputil - -import ( - "encoding/binary" - "fmt" - "net" - "net/netip" -) - -type VpnIp uint32 - -const maxIPv4StringLen = len("255.255.255.255") - -func (ip VpnIp) String() string { - b := make([]byte, maxIPv4StringLen) - - n := ubtoa(b, 0, byte(ip>>24)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>16&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>8&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip&255)) - return string(b[:n]) -} - -func (ip VpnIp) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil -} - -func (ip VpnIp) ToIP() net.IP { - nip := make(net.IP, 4) - binary.BigEndian.PutUint32(nip, uint32(ip)) - return nip -} - -func (ip VpnIp) ToNetIpAddr() netip.Addr { - var nip [4]byte - binary.BigEndian.PutUint32(nip[:], uint32(ip)) - return netip.AddrFrom4(nip) -} - -func Ip2VpnIp(ip []byte) VpnIp { - if len(ip) == 16 { - return VpnIp(binary.BigEndian.Uint32(ip[12:16])) - } - return VpnIp(binary.BigEndian.Uint32(ip)) -} - -func ToNetIpAddr(ip net.IP) (netip.Addr, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip) - } - return addr, nil -} - -func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) { - addr, err := ToNetIpAddr(ipNet.IP) - if err != nil { - return netip.Prefix{}, err - } - ones, bits := ipNet.Mask.Size() - if ones == 0 && bits == 0 { - return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet) - } - return netip.PrefixFrom(addr, ones), nil -} - -// ubtoa encodes the string form of the integer v to dst[start:] and -// returns the number of bytes written to dst. The caller must ensure -// that dst has sufficient length. -func ubtoa(dst []byte, start int, v byte) int { - if v < 10 { - dst[start] = v + '0' - return 1 - } else if v < 100 { - dst[start+1] = v%10 + '0' - dst[start] = v/10 + '0' - return 2 - } - - dst[start+2] = v%10 + '0' - dst[start+1] = (v/10)%10 + '0' - dst[start] = v/100 + '0' - return 3 -} diff --git a/iputil/util_test.go b/iputil/util_test.go deleted file mode 100644 index 712d426..0000000 --- a/iputil/util_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package iputil - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestVpnIp_String(t *testing.T) { - assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String()) - assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String()) - assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String()) - assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String()) - assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String()) - assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String()) -} diff --git a/lighthouse.go b/lighthouse.go index bdabbd7..c87af90 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,43 +7,37 @@ import ( "fmt" "net" "net/netip" + "slices" + "strconv" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) -//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake? -//TODO: nodes are roaming lighthouses, this is bad. How are they learning? - var ErrHostNotKnown = errors.New("host not known") -type netIpAndPort struct { - ip net.IP - port uint16 -} - type LightHouse struct { - //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time + //TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time syncRWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnIp iputil.VpnIp - myVpnZeros iputil.VpnIp - myVpnNet *net.IPNet - punchConn udp.Conn - punchy *Punchy + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + punchConn udp.Conn + punchy *Punchy // Local cache of answers from light houses - // map of vpn Ip to answers - addrMap map[iputil.VpnIp]*RemoteList + // map of vpn addr to answers + addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and @@ -56,26 +50,26 @@ type LightHouse struct { localAllowList atomic.Pointer[LocalAllowList] // used to trigger the HandshakeManager when we receive HostQueryReply - handshakeTrigger chan<- iputil.VpnIp + handshakeTrigger chan<- netip.Addr // staticList exists to avoid having a bool in each addrMap entry // since static should be rare - staticList atomic.Pointer[map[iputil.VpnIp]struct{}] - lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] + staticList atomic.Pointer[map[netip.Addr]struct{}] + lighthouses atomic.Pointer[map[netip.Addr]struct{}] interval atomic.Int64 updateCancel context.CancelFunc ifce EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 - advertiseAddrs atomic.Pointer[[]netIpAndPort] + advertiseAddrs atomic.Pointer[[]netip.AddrPort] - // IP's of relays that can be used by peers to access me - relaysForMe atomic.Pointer[[]iputil.VpnIp] + // Addr's of relays that can be used by peers to access me + relaysForMe atomic.Pointer[[]netip.Addr] - queryChan chan iputil.VpnIp + queryChan chan netip.Addr - calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -84,7 +78,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -97,27 +91,25 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - nebulaPort = uint32(uPort.Port) + nebulaPort = uint32(uPort.Port()) } - ones, _ := myVpnNet.Mask.Size() h := LightHouse{ - syncRWMutex: newSyncRWMutex("lighthouse"), - ctx: ctx, - amLighthouse: amLighthouse, - myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), - myVpnZeros: iputil.VpnIp(32 - ones), - myVpnNet: myVpnNet, - addrMap: make(map[iputil.VpnIp]*RemoteList), - nebulaPort: nebulaPort, - punchConn: pc, - punchy: p, - queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), - l: l, + syncRWMutex: newSyncRWMutex("lighthouse"), + ctx: ctx, + amLighthouse: amLighthouse, + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + addrMap: make(map[netip.Addr]*RemoteList), + nebulaPort: nebulaPort, + punchConn: pc, + punchy: p, + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), + l: l, } - lighthouses := make(map[iputil.VpnIp]struct{}) + lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) - staticList := make(map[iputil.VpnIp]struct{}) + staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) if c.GetBool("stats.lighthouse_metrics", false) { @@ -147,11 +139,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, return &h, nil } -func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} { return *lh.staticList.Load() } -func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { return *lh.lighthouses.Load() } @@ -163,15 +155,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { return lh.localAllowList.Load() } -func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort { +func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort { return *lh.advertiseAddrs.Load() } -func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { +func (lh *LightHouse) GetRelaysForMe() []netip.Addr { return *lh.relaysForMe.Load() } -func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { +func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] { return lh.calculatedRemotes.Load() } @@ -182,25 +174,41 @@ func (lh *LightHouse) GetUpdateInterval() int64 { func (lh *LightHouse) reload(c *config.C, initial bool) error { if initial || c.HasChanged("lighthouse.advertise_addrs") { rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{}) - advAddrs := make([]netIpAndPort, 0) + advAddrs := make([]netip.AddrPort, 0) for i, rawAddr := range rawAdvAddrs { - fIp, fPort, err := udp.ParseIPAndPort(rawAddr) + host, sport, err := net.SplitHostPort(rawAddr) if err != nil { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - if fPort == 0 { - fPort = uint16(lh.nebulaPort) + addrs, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) + if err != nil { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + if len(addrs) == 0 { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) } - if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) { + port, err := strconv.Atoi(sport) + if err != nil { + return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + + if port == 0 { + port = int(lh.nebulaPort) + } + + //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used + addr := addrs[0].Unmap() + _, found := lh.myVpnNetworksTable.Lookup(addr) + if found { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } - advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort}) + advAddrs = append(advAddrs, netip.AddrPortFrom(addr, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) @@ -233,7 +241,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.remoteAllowList.Store(ral) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed") } } @@ -246,7 +253,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.localAllowList.Store(lal) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.local_allow_list has changed") } } @@ -259,7 +265,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.calculatedRemotes.Store(cr) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.calculated_remotes has changed") } } @@ -270,23 +275,22 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { // Entries no longer present must have their (possible) background DNS goroutines stopped. if existingStaticList := lh.staticList.Load(); existingStaticList != nil { lh.RLock() - for staticVpnIp := range *existingStaticList { - if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil { + for staticVpnAddr := range *existingStaticList { + if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { am.hr.Cancel() } } lh.RUnlock() } // Build a new list based on current config. - staticList := make(map[iputil.VpnIp]struct{}) - err := lh.loadStaticMap(c, lh.myVpnNet, staticList) + staticList := make(map[netip.Addr]struct{}) + err := lh.loadStaticMap(c, staticList) if err != nil { return err } lh.staticList.Store(&staticList) if !initial { - //TODO: we should remove any remote list entries for static hosts that were removed/modified? if c.HasChanged("static_host_map") { lh.l.Info("static_host_map has changed") } @@ -303,8 +307,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } if initial || c.HasChanged("lighthouse.hosts") { - lhMap := make(map[iputil.VpnIp]struct{}) - err := lh.parseLighthouses(c, lh.myVpnNet, lhMap) + lhMap := make(map[netip.Addr]struct{}) + err := lh.parseLighthouses(c, lhMap) if err != nil { return err } @@ -323,16 +327,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { if len(c.GetStringSlice("relay.relays", nil)) > 0 { lh.l.Info("Ignoring relays from config because am_relay is true") } - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} lh.relaysForMe.Store(&relaysForMe) case false: - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} for _, v := range c.GetStringSlice("relay.relays", nil) { - lh.l.WithField("relay", v).Info("Read relay from config") - - configRIP := net.ParseIP(v) - if configRIP != nil { - relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP)) + configRIP, err := netip.ParseAddr(v) + if err != nil { + lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") + } else { + lh.l.WithField("relay", v).Info("Read relay from config") + relaysForMe = append(relaysForMe, configRIP) } } lh.relaysForMe.Store(&relaysForMe) @@ -342,21 +347,23 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return nil } -func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { lhs := c.GetStringSlice("lighthouse.hosts", []string{}) if lh.amLighthouse && len(lhs) != 0 { lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } for i, host := range lhs { - ip := net.ParseIP(host) - if ip == nil { - return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) + addr, err := netip.ParseAddr(host) + if err != nil { + return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !tunCidr.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) + + _, found := lh.myVpnNetworksTable.Lookup(addr) + if !found { + return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) } - lhMap[iputil.Ip2VpnIp(ip)] = struct{}{} + lhMap[addr] = struct{}{} } if !lh.amLighthouse && len(lhMap) == 0 { @@ -364,9 +371,9 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma } staticList := lh.GetStaticHostList() - for lhIP, _ := range lhMap { - if _, ok := staticList[lhIP]; !ok { - return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhIP) + for lhAddr, _ := range lhMap { + if _, ok := staticList[lhAddr]; !ok { + return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) } } @@ -399,7 +406,7 @@ func getStaticMapNetwork(c *config.C) (string, error) { return network, nil } -func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error { d, err := getStaticMapCadence(c) if err != nil { return err @@ -410,35 +417,35 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return err } - lookup_timeout, err := getStaticMapLookupTimeout(c) + lookupTimeout, err := getStaticMapLookupTimeout(c) if err != nil { return err } - shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) + shm := c.GetMap("static_host_map", map[string]any{}) i := 0 for k, v := range shm { - rip := net.ParseIP(fmt.Sprintf("%v", k)) - if rip == nil { - return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil) + vpnAddr, err := netip.ParseAddr(fmt.Sprintf("%v", k)) + if err != nil { + return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !tunCidr.Contains(rip) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil) + _, found := lh.myVpnNetworksTable.Lookup(vpnAddr) + if !found { + return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } - vpnIp := iputil.Ip2VpnIp(rip) - vals, ok := v.([]interface{}) + vals, ok := v.([]any) if !ok { - vals = []interface{}{v} + vals = []any{v} } remoteAddrs := []string{} for _, v := range vals { remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) } - err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) + err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnAddr, remoteAddrs, staticList) if err != nil { return err } @@ -448,12 +455,12 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { - if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip) +func (lh *LightHouse) Query(vpnAddr netip.Addr) *RemoteList { + if !lh.IsLighthouseAddr(vpnAddr) { + lh.QueryServer(vpnAddr) } lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { lh.RUnlock() return v } @@ -462,19 +469,19 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { - // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses - if lh.amLighthouse || lh.IsLighthouseIP(ip) { +func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) { + // Don't put lighthouse addrs in the query channel because we can't query lighthouses about lighthouses + if lh.amLighthouse || lh.IsLighthouseAddr(vpnAddr) { return } chanDebugSend("lighthouse-query-chan") - lh.queryChan <- ip + lh.queryChan <- vpnAddr } -func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList { lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddrs[0]]; ok { lh.RUnlock() return v } @@ -483,24 +490,27 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { lh.Lock() defer lh.Unlock() // Add an entry if we don't already have one - return lh.unlockedGetRemoteList(ip) + return lh.unlockedGetRemoteList(vpnAddrs) } // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing -// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp +// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnAddr // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? - if v, ok := lh.addrMap[vpnIp]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { // Swap lh lock for remote list lock v.RLock() defer v.RUnlock() lh.RUnlock() - // vpnIp should also be the owner here since we are a lighthouse. - c := v.cache[vpnIp] + // We may be asking about a non primary address so lets get the primary address + if slices.Contains(v.vpnAddrs, vpnAddr) { + vpnAddr = v.vpnAddrs[0] + } + c := v.cache[vpnAddr] // Make sure we have if c != nil { n, err := f(c) @@ -512,150 +522,161 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { +func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { // First we check the static mapping // and do nothing if it is there - if _, ok := lh.GetStaticHostList()[vpnIp]; ok { + if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { return } lh.Lock() - //l.Debugln(lh.addrMap) - delete(lh.addrMap, vpnIp) - - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", vpnIp) + rm, ok := lh.addrMap[allVpnAddrs[0]] + if ok { + for _, addr := range allVpnAddrs { + srm := lh.addrMap[addr] + if srm == rm { + delete(lh.addrMap, addr) + if lh.l.Level >= logrus.DebugLevel { + lh.l.Debugf("deleting %s from lighthouse.", addr) + } + } + } } - lh.Unlock() } -// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner +// AddStaticRemote adds a static host entry for vpnAddr as ourselves as the owner // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnAddr netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() ctx := lh.ctx lh.Unlock() hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() { - // This callback runs whenever the DNS hostname resolver finds a different set of IP's + // This callback runs whenever the DNS hostname resolver finds a different set of addr's // in its resolution for hostnames. am.Lock() defer am.Unlock() am.shouldRebuild = true }) if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + return util.NewContextualError("Static host address could not be parsed", m{"vpnAddr": vpnAddr, "entry": i + 1}, err) } am.unlockedSetHostnamesResults(hr) - for _, addrPort := range hr.GetIPs() { - + for _, addrPort := range hr.GetAddrs() { + if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { + continue + } switch { case addrPort.Addr().Is4(): - to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV4(vpnIp, to) { - continue - } - am.unlockedPrependV4(lh.myVpnIp, to) + am.unlockedPrependV4(lh.myVpnNetworks[0].Addr(), netAddrToProtoV4AddrPort(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV6(vpnIp, to) { - continue - } - am.unlockedPrependV6(lh.myVpnIp, to) + am.unlockedPrependV6(lh.myVpnNetworks[0].Addr(), netAddrToProtoV6AddrPort(addrPort.Addr(), addrPort.Port())) } } // Mark it as static in the caller provided map - staticList[vpnIp] = struct{}{} + staticList[vpnAddr] = struct{}{} return nil } // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool { tree := lh.getCalculatedRemotes() if tree == nil { return false } - ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnAddr) if !ok { return false } - var calculated []*Ip4AndPort + var calculatedV4 []*V4AddrPort + var calculatedV6 []*V6AddrPort for _, cr := range calculatedRemotes { - c := cr.Apply(vpnIp) - if c != nil { - calculated = append(calculated, c) + if vpnAddr.Is4() { + c := cr.ApplyV4(vpnAddr) + if c != nil { + calculatedV4 = append(calculatedV4, c) + } + } else if vpnAddr.Is6() { + c := cr.ApplyV6(vpnAddr) + if c != nil { + calculatedV6 = append(calculatedV6, c) + } } } lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4) + if len(calculatedV4) > 0 { + am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV4, lh.unlockedShouldAddV4) + } - return len(calculated) > 0 + if len(calculatedV6) > 0 { + am.unlockedSetV6(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV6, lh.unlockedShouldAddV6) + } + + return len(calculatedV4) > 0 || len(calculatedV6) > 0 } -// unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { - am, ok := lh.addrMap[vpnIp] +// unlockedGetRemoteList +// assumes you have the lh lock +func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { + am, ok := lh.addrMap[allAddrs[0]] if !ok { - am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) - lh.addrMap[vpnIp] = am + am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) + for _, addr := range allAddrs { + lh.addrMap[addr] = am + } } return am } -func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { - switch { - case to.Is4(): - ipBytes := to.As4() - ip := iputil.Ip2VpnIp(ipBytes[:]) - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { - return false - } - case to.Is6(): - ipBytes := to.As16() - - hi := binary.BigEndian.Uint64(ipBytes[:8]) - lo := binary.BigEndian.Uint64(ipBytes[8:]) - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { - return false - } +func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). + Trace("remoteAllowList.Allow") } + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(to) + if found { + return false + } + return true } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) +func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { + udpAddr := protoV4AddrPortToNetAddrPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) + if found { return false } @@ -663,83 +684,43 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo) +func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { + udpAddr := protoV6AddrPortToNetAddrPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - // We don't check our vpn network here because nebula does not support ipv6 on the inside if !allow { return false } + _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) + if found { + return false + } + return true } -func lhIp6ToIp(v *Ip6AndPort) net.IP { - ip := make(net.IP, 16) - binary.BigEndian.PutUint64(ip[:8], v.Hi) - binary.BigEndian.PutUint64(ip[8:], v.Lo) - return ip -} - -func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool { - if _, ok := lh.GetLighthouses()[vpnIp]; ok { +func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { + if _, ok := lh.GetLighthouses()[vpnAddr]; ok { return true } return false } -func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta { - return &NebulaMeta{ - Type: NebulaMeta_HostQuery, - Details: &NebulaMetaDetails{ - VpnIp: uint32(VpnIp), - }, +// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake +// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially +func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool { + l := lh.GetLighthouses() + for _, a := range vpnAddr { + if _, ok := l[a]; ok { + return true + } } -} - -func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { - ipp := Ip4AndPort{Port: port} - ipp.Ip = uint32(iputil.Ip2VpnIp(ip)) - return &ipp -} - -func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { - v4Addr := ip.As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), - Port: uint32(port), - } -} - -func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip[:8]), - Lo: binary.BigEndian.Uint64(ip[8:]), - Port: port, - } -} - -func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { - ip6Addr := ip.As16() - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip6Addr[:8]), - Lo: binary.BigEndian.Uint64(ip6Addr[8:]), - Port: uint32(port), - } -} -func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { - ip := ipp.Ip - return udp.NewAddr( - net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)), - uint16(ipp.Port), - ) -} - -func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { - return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) + return false } func (lh *LightHouse) startQueryWorker() { @@ -757,31 +738,85 @@ func (lh *LightHouse) startQueryWorker() { select { case <-lh.ctx.Done(): return - case ip := <-lh.queryChan: - lh.innerQueryServer(ip, nb, out) + case addr := <-lh.queryChan: + lh.innerQueryServer(addr, nb, out) } } }() } -func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { - if lh.IsLighthouseIP(ip) { +func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { + if lh.IsLighthouseAddr(addr) { return } - // Send a query to the lighthouses and hope for the best next time - query, err := NewLhQueryByInt(ip).Marshal() - if err != nil { - lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") - return + msg := &NebulaMeta{ + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{}, } + var v1Query, v2Query []byte + var err error + var v cert.Version + queried := 0 lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) - for n := range lighthouses { - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) + for lhVpnAddr := range lighthouses { + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion + } + + if v == cert.Version1 { + if !addr.Is4() { + lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). + Error("Can't query lighthouse for v6 address using a v1 protocol") + continue + } + + if v1Query == nil { + b := addr.As4() + msg.Details.VpnAddr = nil + msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + + v1Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v1 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out) + queried++ + + } else if v == cert.Version2 { + if v2Query == nil { + msg.Details.OldVpnAddr = 0 + msg.Details.VpnAddr = netAddrToProtoAddr(addr) + + v2Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v2 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out) + queried++ + + } else { + lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + continue + } } + + lh.metricTx(NebulaMeta_HostQuery, int64(queried)) } func (lh *LightHouse) StartUpdateWorker() { @@ -811,60 +846,120 @@ func (lh *LightHouse) StartUpdateWorker() { } func (lh *LightHouse) SendUpdate() { - var v4 []*Ip4AndPort - var v6 []*Ip6AndPort + var v4 []*V4AddrPort + var v6 []*V6AddrPort for _, e := range lh.GetAdvertiseAddrs() { - if ip := e.ip.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port))) + if e.Addr().Is4() { + v4 = append(v4, netAddrToProtoV4AddrPort(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port))) + v6 = append(v6, netAddrToProtoV6AddrPort(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() - for _, e := range *localIps(lh.l, lal) { - if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) { + for _, e := range localAddrs(lh.l, lal) { + _, found := lh.myVpnNetworksTable.Lookup(e) + if found { continue } - // Only add IPs that aren't my VPN/tun IP - if ip := e.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort)) + // Only add addrs that aren't my VPN/tun networks + if e.Is4() { + v4 = append(v4, netAddrToProtoV4AddrPort(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort)) + v6 = append(v6, netAddrToProtoV6AddrPort(e, uint16(lh.nebulaPort))) } } - var relays []uint32 - for _, r := range lh.GetRelaysForMe() { - relays = append(relays, (uint32)(r)) - } - - m := &NebulaMeta{ - Type: NebulaMeta_HostUpdateNotification, - Details: &NebulaMetaDetails{ - VpnIp: uint32(lh.myVpnIp), - Ip4AndPorts: v4, - Ip6AndPorts: v6, - RelayVpnIp: relays, - }, - } - - lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lighthouses))) nb := make([]byte, 12, 12) out := make([]byte, mtu) - mm, err := m.Marshal() - if err != nil { - lh.l.WithError(err).Error("Error while marshaling for lighthouse update") - return + var v1Update, v2Update []byte + var err error + updated := 0 + lighthouses := lh.GetLighthouses() + + for lhVpnAddr := range lighthouses { + var v cert.Version + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion + } + if v == cert.Version1 { + if v1Update == nil { + if !lh.myVpnNetworks[0].Addr().Is4() { + lh.l.WithField("lighthouseAddr", lhVpnAddr). + Warn("cannot update lighthouse using v1 protocol without an IPv4 address") + continue + } + var relays []uint32 + for _, r := range lh.GetRelaysForMe() { + if !r.Is4() { + continue + } + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) + } + b := lh.myVpnNetworks[0].Addr().As4() + msg := NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + V4AddrPorts: v4, + V6AddrPorts: v6, + OldRelayVpnAddrs: relays, + OldVpnAddr: binary.BigEndian.Uint32(b[:]), + }, + } + + v1Update, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse v1 update") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out) + updated++ + + } else if v == cert.Version2 { + if v2Update == nil { + var relays []*Addr + for _, r := range lh.GetRelaysForMe() { + relays = append(relays, netAddrToProtoAddr(r)) + } + + msg := NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + V4AddrPorts: v4, + V6AddrPorts: v6, + RelayVpnAddrs: relays, + VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()), + }, + } + + v2Update, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse v2 update") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) + updated++ + + } else { + lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) + continue + } } - for vpnIp := range lighthouses { - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) - } + lh.metricTx(NebulaMeta_HostUpdateNotification, int64(updated)) } type LightHouseHandler struct { @@ -907,34 +1002,29 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { lhh.meta.Reset() // Keep the array memory around - details.Ip4AndPorts = details.Ip4AndPorts[:0] - details.Ip6AndPorts = details.Ip6AndPorts[:0] - details.RelayVpnIp = details.RelayVpnIp[:0] + details.V4AddrPorts = details.V4AddrPorts[:0] + details.V6AddrPorts = details.V6AddrPorts[:0] + details.RelayVpnAddrs = details.RelayVpnAddrs[:0] + details.OldRelayVpnAddrs = details.OldRelayVpnAddrs[:0] + details.OldVpnAddr = 0 + details.VpnAddr = nil lhh.meta.Details = details return lhh.meta } -func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { - lhh.HandleRequest(rAddr, vpnIp, p, f) - } -} - -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") - //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") - //TODO: send recv_error? return } @@ -942,24 +1032,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, vpnIp, rAddr, w) + lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, vpnIp) + lhh.handleHostQueryReply(n, fromVpnAddrs) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, vpnIp, w) + lhh.handleHostUpdateNotification(n, fromVpnAddrs, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, vpnIp, w) + lhh.handleHostPunchNotification(n, fromVpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -968,15 +1058,37 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, return } - //TODO: we can DRY this further - reqVpnIp := n.Details.VpnIp - //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) { + useVersion := cert.Version1 + var queryVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnAddr = netip.AddrFrom4(b) + useVersion = 1 + } else if n.Details.VpnAddr != nil { + queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = 2 + } else { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery") + } + return + } + + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnAddr, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply - n.Details.VpnIp = reqVpnIp + if useVersion == cert.Version1 { + if !queryVpnAddr.Is4() { + return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") + } + b := queryVpnAddr.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else { + n.Details.VpnAddr = netAddrToProtoAddr(queryVpnAddr) + } - lhh.coalesceAnswers(c, n) + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -986,20 +1098,51 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) - // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { + lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) +} + +// sendHostPunchNotification signals the other side to punch some zero byte udp packets +func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) { + whereToPunch := fromVpnAddrs[0] + found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - n.Details.VpnIp = uint32(vpnIp) + targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest) + var useVersion cert.Version + if targetHI == nil { + useVersion = lhh.lh.ifce.GetCertState().defaultVersion + } else { + crt := targetHI.GetCert().Certificate + useVersion = crt.Version() + // we can only retarget if we have a hostinfo + newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs) + if ok { + whereToPunch = newDest + } else { + //TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee + //choosing to do nothing for now, but maybe we return an error? + } + } - lhh.coalesceAnswers(c, n) + if useVersion == cert.Version1 { + if !whereToPunch.Is4() { + return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") + } + b := whereToPunch.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch) + } else { + return 0, errors.New("unsupported version") + } + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -1009,110 +1152,175 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") return } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { +func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { if c.v4 != nil { if c.v4.learned != nil { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned) } if c.v4.reported != nil && len(c.v4.reported) > 0 { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...) } } if c.v6 != nil { if c.v6.learned != nil { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned) } if c.v6.reported != nil && len(c.v6.reported) > 0 { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...) } } if c.relay != nil { - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...) + if v == cert.Version1 { + b := [4]byte{} + for _, r := range c.relay.relay { + if !r.Is4() { + continue + } + + b = r.As4() + n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) + } + + } else if v == cert.Version2 { + for _, r := range c.relay.relay { + n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) + } + + } else { + //TODO: CERT-V2 don't panic + panic("unsupported version") + } } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs []netip.Addr) { + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp)) + + var certVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + certVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + relays := n.Details.GetRelays() + + am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr}) am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + am.unlockedSetV4(fromVpnAddrs[0], certVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], certVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp): + case lhh.lh.handshakeTrigger <- certVpnAddr: default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) + lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) } return } - //Simple check that the host sent this not someone else - if n.Details.VpnIp != uint32(vpnIp) { + var detailsVpnAddr netip.Addr + useVersion := cert.Version1 + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + detailsVpnAddr = netip.AddrFrom4(b) + useVersion = cert.Version1 + } else if n.Details.VpnAddr != nil { + detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = cert.Version2 + } else { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update") + lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification") } return } + //TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? + //TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right? + //Simple check that the host sent this not someone else + if !slices.Contains(fromVpnAddrs, detailsVpnAddr) { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") + } + return + } + + relays := n.Details.GetRelays() + lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(vpnIp) + am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - n.Details.VpnIp = uint32(vpnIp) - ln, err := n.MarshalTo(lhh.pb) + if useVersion == cert.Version1 { + if !fromVpnAddrs[0].Is4() { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + return + } + vpnAddrB := fromVpnAddrs[0].As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:]) + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) + } else { + lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") + return + } + + ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host update ack") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") return } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { + //It's possible the lighthouse is communicating with us using a non primary vpn addr, + //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. + //maybe one day we'll have a better idea, if it matters. + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } empty := []byte{0} - punch := func(vpnPeer *udp.Addr) { - if vpnPeer == nil { + punch := func(vpnPeer netip.AddrPort) { + if !vpnPeer.IsValid() { return } @@ -1123,39 +1331,123 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i }() if lhh.l.Level >= logrus.DebugLevel { - //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp)) + var logVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + logVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) } } - for _, a := range n.Details.Ip4AndPorts { - punch(NewUDPAddrFromLH4(a)) + for _, a := range n.Details.V4AddrPorts { + punch(protoV4AddrPortToNetAddrPort(a)) } - for _, a := range n.Details.Ip6AndPorts { - punch(NewUDPAddrFromLH6(a)) + for _, a := range n.Details.V6AddrPorts { + punch(protoV6AddrPortToNetAddrPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - queryVpnIp := iputil.VpnIp(n.Details.VpnIp) + var queryVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", queryVpnIp) + lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr) } //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine // managed by a channel. - w.SendMessageToVpnIp(header.Test, header.TestRequest, queryVpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }() } } -// ipMaskContains checks if testIp is contained by ip after applying a cidr -// zeros is 32 - bits from net.IPMask.Size() -func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool { - return (testIp^ip)>>zeros == 0 +func protoAddrToNetAddr(addr *Addr) netip.Addr { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], addr.Hi) + binary.BigEndian.PutUint64(b[8:], addr.Lo) + return netip.AddrFrom16(b).Unmap() +} + +func protoV4AddrPortToNetAddrPort(ap *V4AddrPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ap.Addr) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ap.Port)) +} + +func protoV6AddrPortToNetAddrPort(ap *V6AddrPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ap.Hi) + binary.BigEndian.PutUint64(b[8:], ap.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ap.Port)) +} + +func netAddrToProtoAddr(addr netip.Addr) *Addr { + b := addr.As16() + return &Addr{ + Hi: binary.BigEndian.Uint64(b[:8]), + Lo: binary.BigEndian.Uint64(b[8:]), + } +} + +func netAddrToProtoV4AddrPort(addr netip.Addr, port uint16) *V4AddrPort { + v4Addr := addr.As4() + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(port), + } +} + +func netAddrToProtoV6AddrPort(addr netip.Addr, port uint16) *V6AddrPort { + v6Addr := addr.As16() + return &V6AddrPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(port), + } +} + +func (d *NebulaMetaDetails) GetRelays() []netip.Addr { + var relays []netip.Addr + if len(d.OldRelayVpnAddrs) > 0 { + b := [4]byte{} + for _, r := range d.OldRelayVpnAddrs { + binary.BigEndian.PutUint32(b[:], r) + relays = append(relays, netip.AddrFrom4(b)) + } + } + + if len(d.RelayVpnAddrs) > 0 { + for _, r := range d.RelayVpnAddrs { + relays = append(relays, protoAddrToNetAddr(r)) + } + } + return relays +} + +// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able +func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) { + for i := range prefixes { + for j := range addrs { + if prefixes[i].Contains(addrs[j]) { + return addrs[j], true + } + } + } + return netip.Addr{}, false } diff --git a/lighthouse_test.go b/lighthouse_test.go index 66427e3..6a541c2 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -2,150 +2,151 @@ package nebula import ( "context" + "encoding/binary" "fmt" - "net" + "net/netip" "testing" + "github.com/gaissmai/bart" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v2" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) -//TODO: Add a test to ensure udpAddr is copied and not reused - func TestOldIPv4Only(t *testing.T) { // This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility b := []byte{8, 129, 130, 132, 80, 16, 10} - var m Ip4AndPort + var m V4AddrPort err := m.Unmarshal(b) - assert.NoError(t, err) - assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String()) -} - -func TestNewLhQuery(t *testing.T) { - myIp := net.ParseIP("192.1.1.1") - myIpint := iputil.Ip2VpnIp(myIp) - - // Generating a new lh query should work - a := NewLhQueryByInt(myIpint) - - // The result should be a nebulameta protobuf - assert.IsType(t, &NebulaMeta{}, a) - - // It should also Marshal fine - b, err := a.Marshal() - assert.Nil(t, err) - - // and then Unmarshal fine - n := &NebulaMeta{} - err = n.Unmarshal(b) - assert.Nil(t, err) - + require.NoError(t, err) + ip := netip.MustParseAddr("10.1.1.1") + bp := ip.As4() + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) } func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) - assert.Nil(t, err) + c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}} + c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} + _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) - assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") + c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}} + c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}} + _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "hosts": []interface{}{lh1}, + c.Settings["lighthouse"] = map[string]any{ + "hosts": []any{lh1}, "interval": "1s", } - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) - assert.NoError(t, err) + c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + require.NoError(t, err) lh.ifce = &mockEncWriter{} // The first one routine is kicked off by main.go currently, lets make sure that one dies - c.ReloadConfigString("lighthouse:\n interval: 5") + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) assert.Equal(t, int64(5), lh.interval.Load()) // Subsequent calls are killed off by the LightHouse.Reload function - c.ReloadConfigString("lighthouse:\n interval: 10") + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) assert.Equal(t, int64(10), lh.interval.Load()) // If this completes then nothing is stealing our reload routine - c.ReloadConfigString("lighthouse:\n interval: 11") + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) assert.Equal(t, int64(11), lh.interval.Load()) } func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") - - c := config.NewC(l) - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) - if !assert.NoError(b, err) { - b.Fatal() + myVpnNet := netip.MustParsePrefix("10.128.0.1/0") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, } - hAddr := udp.NewAddrFromString("4.5.6.7:12345") - hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = NewRemoteList(nil) - lh.addrMap[3].unlockedSetV4( - 3, - 3, - []*Ip4AndPort{ - NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), - NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), + c := config.NewC(l) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + require.NoError(b, err) + + hAddr := netip.MustParseAddrPort("4.5.6.7:12345") + hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") + + vpnIp3 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil) + lh.addrMap[vpnIp3].unlockedSetV4( + vpnIp3, + vpnIp3, + []*V4AddrPort{ + netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()), + netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) - rAddr := udp.NewAddrFromString("1.2.2.3:12345") - rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = NewRemoteList(nil) - lh.addrMap[2].unlockedSetV4( - 3, - 3, - []*Ip4AndPort{ - NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), - NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), + rAddr := netip.MustParseAddrPort("1.2.2.3:12345") + rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") + vpnIp2 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil) + lh.addrMap[vpnIp2].unlockedSetV4( + vpnIp3, + vpnIp3, + []*V4AddrPort{ + netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()), + netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) mw := &mockEncWriter{} + hi := []netip.Addr{vpnIp2} b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 4, - Ip4AndPorts: nil, + OldVpnAddr: 4, + V4AddrPorts: nil, }, } p, err := req.Marshal() - assert.NoError(b, err) + require.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -153,15 +154,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 3, - Ip4AndPorts: nil, + OldVpnAddr: 3, + V4AddrPorts: nil, }, } p, err := req.Marshal() - assert.NoError(b, err) + require.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) } @@ -169,71 +170,80 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { func TestLighthouse_Memory(t *testing.T) { l := test.NewLogger() - myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} - myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} - myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242} - myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242} - myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242} - myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243} - myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244} - myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245} - myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246} - myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247} - myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248} - myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249} - myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2")) + myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242") + myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242") + myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242") + myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242") + myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242") + myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243") + myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244") + myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245") + myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246") + myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247") + myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248") + myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249") + myVpnIp := netip.MustParseAddr("10.128.0.2") - theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242} - theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242} - theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242} - theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242} - theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242} - theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3")) + theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242") + theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242") + theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242") + theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242") + theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242") + theirVpnIp := netip.MustParseAddr("10.128.0.3") c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} - c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) - assert.NoError(t, err) + c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true} + c.Settings["listen"] = map[string]any{"port": 4242} + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh.ifce = &mockEncWriter{} + require.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3) // Grow it back to 2 - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it - newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) + newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Have both hosts ask about the other r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Make sure we didn't get changed r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Ensure proper ordering and limiting // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe) newLHHostUpdate( myUdpAddr0, myVpnIp, - []*udp.Addr{ + []netip.AddrPort{ myUdpAddr1, myUdpAddr2, myUdpAddr3, @@ -251,46 +261,60 @@ func TestLighthouse_Memory(t *testing.T) { r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray( t, - r.msg.Details.Ip4AndPorts, + r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, ) // Make sure we won't add ips in our vpn network - bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242} - bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242} - good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242} - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh) + bad1 := netip.MustParseAddrPort("10.128.0.99:4242") + bad2 := netip.MustParseAddrPort("10.128.0.100:4242") + good := netip.MustParseAddrPort("1.128.0.99:4242") + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, good) } func TestLighthouse_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} - c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) - assert.NoError(t, err) + c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true} + c.Settings["listen"] = map[string]any{"port": 4242} - nc := map[interface{}]interface{}{ - "static_host_map": map[interface{}]interface{}{ - "10.128.0.2": []interface{}{"1.1.1.1:4242"}, + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + require.NoError(t, err) + + nc := map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"1.1.1.1:4242"}, }, } rc, err := yaml.Marshal(nc) - assert.NoError(t, err) + require.NoError(t, err) c.ReloadConfigString(string(rc)) err = lh.reload(c, false) - assert.NoError(t, err) + require.NoError(t, err) } -func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { +func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { req := &NebulaMeta{ - Type: NebulaMeta_HostQuery, - Details: &NebulaMetaDetails{ - VpnIp: uint32(queryVpnIp), - }, + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{}, + } + + if queryVpnIp.Is4() { + bip := queryVpnIp.As4() + req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:]) + } else { + req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp) } b, err := req.Marshal() @@ -302,21 +326,29 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, myVpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) return w.lastReply } -func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) { +func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) { req := &NebulaMeta{ - Type: NebulaMeta_HostUpdateNotification, - Details: &NebulaMetaDetails{ - VpnIp: uint32(vpnIp), - Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), - }, + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{}, } - for k, v := range addrs { - req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)} + if vpnIp.Is4() { + bip := vpnIp.As4() + req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:]) + } else { + req.Details.VpnAddr = netAddrToProtoAddr(vpnIp) + } + + for _, v := range addrs { + if v.Addr().Is4() { + req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port())) + } else { + req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port())) + } } b, err := req.Marshal() @@ -325,96 +357,25 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, vpnIp, b, w) -} - -//TODO: this is a RemoteList test -//func Test_lhRemoteAllowList(t *testing.T) { -// l := NewLogger() -// c := NewConfig(l) -// c.Settings["remoteallowlist"] = map[interface{}]interface{}{ -// "10.20.0.0/12": false, -// } -// allowList, err := c.GetAllowList("remoteallowlist", false) -// assert.Nil(t, err) -// -// lh1 := "10.128.0.2" -// lh1IP := net.ParseIP(lh1) -// -// udpServer, _ := NewListener(l, "0.0.0.0", 0, true) -// -// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) -// lh.SetRemoteAllowList(allowList) -// -// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap -// remote1IP := net.ParseIP("10.20.0.3") -// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP)) -// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242)) -// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)]) -// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{})) -// -// // Make sure a good ip enters the cache and addrMap -// remote2IP := net.ParseIP("10.128.0.3") -// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false) -// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr) -// -// // Another good ip gets into the cache, ordering is inverted -// remote3IP := net.ParseIP("10.128.0.4") -// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false) -// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr) -// -// // If we exceed the length limit we should only have the most recent addresses -// addedAddrs := []*udpAddr{} -// for i := 0; i < 11; i++ { -// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false) -// // The first entry here is a duplicate, don't add it to the assert list -// if i != 0 { -// addedAddrs = append(addedAddrs, remoteUDPAddr) -// } -// } -// -// // We should only have the last 10 of what we tried to add -// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses") -// assertUdpAddrInArray( -// t, -// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), -// addedAddrs[0], -// addedAddrs[1], -// addedAddrs[2], -// addedAddrs[3], -// addedAddrs[4], -// addedAddrs[5], -// addedAddrs[6], -// addedAddrs[7], -// addedAddrs[8], -// addedAddrs[9], -// ) -//} - -func Test_ipMaskContains(t *testing.T) { - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255")))) - assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) + lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) } type testLhReply struct { nebType header.MessageType nebSubType header.MessageSubType - vpnIp iputil.VpnIp + vpnIp netip.Addr msg *NebulaMeta } type testEncWriter struct { - lastReply testLhReply - metaFilter *NebulaMeta_MessageType + lastReply testLhReply + metaFilter *NebulaMeta_MessageType + protocolVersion cert.Version } func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } -func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { +func (tw *testEncWriter) Handshake(vpnIp netip.Addr) { } func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { @@ -424,7 +385,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M tw.lastReply = testLhReply{ nebType: t, nebSubType: st, - vpnIp: hostinfo.vpnIp, + vpnIp: hostinfo.vpnAddrs[0], msg: msg, } } @@ -434,7 +395,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -451,36 +412,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } } +func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return nil +} + +func (tw *testEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: tw.protocolVersion} +} + // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { +func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { - if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) + h := protoV4AddrPortToNetAddrPort(have[k]) + if !(h == w) { + assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } } } -// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match -func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) { - if !assert.Len(t, have, len(want)) { - return - } +func Test_findNetworkUnion(t *testing.T) { + var out netip.Addr + var ok bool - for k, w := range want { - if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have)) - } - } -} + tenDot := netip.MustParsePrefix("10.0.0.0/8") + oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16") + fe80 := netip.MustParsePrefix("fe80::/8") + fc00 := netip.MustParsePrefix("fc00::/7") -func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr { - addrs := make([]*udp.Addr, len(ips)) - for k, v := range ips { - addrs[k] = NewUDPAddrFromLH4(v) - } - return addrs + a1 := netip.MustParseAddr("10.0.0.1") + afe81 := netip.MustParseAddr("fe80::1") + + //simple + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed lengths + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed family + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //ordering + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //some mismatches + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //falsey cases + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) } diff --git a/main.go b/main.go index 7a0a0cf..b278fa6 100644 --- a/main.go +++ b/main.go @@ -2,9 +2,9 @@ package nebula import ( "context" - "encoding/binary" "fmt" "net" + "net/netip" "time" "github.com/sirupsen/logrus" @@ -13,10 +13,10 @@ import ( "github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) @@ -60,16 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - certificate := pki.GetCertState().Certificate - fw, err := NewFirewallFromConfig(l, certificate, c) + fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - // TODO: make sure mask is 4 bytes - tunCidr := certificate.Details.Ips[0] - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) @@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, tunCidr, routines) + tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } @@ -150,21 +146,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { rawListenHost := c.GetString("listen.host", "0.0.0.0") - var listenHost *net.IPAddr + var listenHost netip.Addr if rawListenHost == "[::]" { // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. - listenHost = &net.IPAddr{IP: net.IPv6zero} + listenHost = netip.IPv6Unspecified() } else { - listenHost, err = net.ResolveIPAddr("ip", rawListenHost) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } + if len(ips) == 0 { + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) + } + listenHost = ips[0].Unmap() } for i := 0; i < routines; i++ { - l.Infof("listening %q %d", listenHost.IP, port) - udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) + l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } @@ -178,14 +178,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - port = int(uPort.Port) + port = int(uPort.Port()) } } } - hostMap := NewHostMapFromConfig(l, tunCidr, c) + hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) - lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } @@ -201,7 +201,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeConfig := HandshakeConfig{ tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries), + retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), useRelays: useRelays, @@ -228,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Inside: tun, Outside: udpConns[0], pki: pki, - Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, HandshakeManager: handshakeManager, @@ -250,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg l: l, } - switch ifConfig.Cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian - default: - return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) - } - var ifce *Interface if !configTest { ifce, err = NewInterface(ctx, ifConfig) @@ -266,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, fmt.Errorf("failed to initialize interface: %s", err) } - // TODO: Better way to attach these, probably want a new interface in InterfaceConfig - // I don't want to make this initial commit too far-reaching though ifce.writers = udpConns lightHouse.ifce = ifce @@ -279,8 +267,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg go handshakeManager.Run(ctx) } - // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept - // a context so that they can exit when the context is Done. statsStart, err := startStats(l, c, buildVersion, configTest) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) @@ -290,7 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, nil } - //TODO: check if we _should_ be emitting stats go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) attachCommands(l, c, ssh, ifce) @@ -299,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg var dnsStart func() if lightHouse.amLighthouse && serveDns { l.Debugln("Starting dns server") - dnsStart = dnsMain(l, hostMap, c) + dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) } return &Control{ diff --git a/message_metrics.go b/message_metrics.go index 94bb02f..10e8472 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -7,8 +7,6 @@ import ( "github.com/slackhq/nebula/header" ) -//TODO: this can probably move into the header package - type MessageMetrics struct { rx [][]metrics.Counter tx [][]metrics.Counter diff --git a/metadata.go b/metadata.go deleted file mode 100644 index 6a023ab..0000000 --- a/metadata.go +++ /dev/null @@ -1,18 +0,0 @@ -package nebula - -/* - -import ( - proto "google.golang.org/protobuf/proto" -) - -func HandleMetaProto(p []byte) { - m := &NebulaMeta{} - err := proto.Unmarshal(p, m) - if err != nil { - l.Debugf("problem unmarshaling meta message: %s", err) - } - //fmt.Println(m) -} - -*/ diff --git a/nebula.pb.go b/nebula.pb.go index b3c723a..2fd2ff6 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -96,7 +96,7 @@ func (x NebulaPing_MessageType) String() string { } func (NebulaPing_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4, 0} + return fileDescriptor_2d65afa7693df5ef, []int{5, 0} } type NebulaControl_MessageType int32 @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7, 0} + return fileDescriptor_2d65afa7693df5ef, []int{8, 0} } type NebulaMeta struct { @@ -180,11 +180,13 @@ func (m *NebulaMeta) GetDetails() *NebulaMetaDetails { } type NebulaMetaDetails struct { - VpnIp uint32 `protobuf:"varint,1,opt,name=VpnIp,proto3" json:"VpnIp,omitempty"` - Ip4AndPorts []*Ip4AndPort `protobuf:"bytes,2,rep,name=Ip4AndPorts,proto3" json:"Ip4AndPorts,omitempty"` - Ip6AndPorts []*Ip6AndPort `protobuf:"bytes,4,rep,name=Ip6AndPorts,proto3" json:"Ip6AndPorts,omitempty"` - RelayVpnIp []uint32 `protobuf:"varint,5,rep,packed,name=RelayVpnIp,proto3" json:"RelayVpnIp,omitempty"` - Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` + OldVpnAddr uint32 `protobuf:"varint,1,opt,name=OldVpnAddr,proto3" json:"OldVpnAddr,omitempty"` // Deprecated: Do not use. + VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` + OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. + RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` + V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` + V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` + Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` } func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} } @@ -220,30 +222,46 @@ func (m *NebulaMetaDetails) XXX_DiscardUnknown() { var xxx_messageInfo_NebulaMetaDetails proto.InternalMessageInfo -func (m *NebulaMetaDetails) GetVpnIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldVpnAddr() uint32 { if m != nil { - return m.VpnIp + return m.OldVpnAddr } return 0 } -func (m *NebulaMetaDetails) GetIp4AndPorts() []*Ip4AndPort { +func (m *NebulaMetaDetails) GetVpnAddr() *Addr { if m != nil { - return m.Ip4AndPorts + return m.VpnAddr } return nil } -func (m *NebulaMetaDetails) GetIp6AndPorts() []*Ip6AndPort { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldRelayVpnAddrs() []uint32 { if m != nil { - return m.Ip6AndPorts + return m.OldRelayVpnAddrs } return nil } -func (m *NebulaMetaDetails) GetRelayVpnIp() []uint32 { +func (m *NebulaMetaDetails) GetRelayVpnAddrs() []*Addr { if m != nil { - return m.RelayVpnIp + return m.RelayVpnAddrs + } + return nil +} + +func (m *NebulaMetaDetails) GetV4AddrPorts() []*V4AddrPort { + if m != nil { + return m.V4AddrPorts + } + return nil +} + +func (m *NebulaMetaDetails) GetV6AddrPorts() []*V6AddrPort { + if m != nil { + return m.V6AddrPorts } return nil } @@ -255,23 +273,23 @@ func (m *NebulaMetaDetails) GetCounter() uint32 { return 0 } -type Ip4AndPort struct { - Ip uint32 `protobuf:"varint,1,opt,name=Ip,proto3" json:"Ip,omitempty"` - Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +type Addr struct { + Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` + Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` } -func (m *Ip4AndPort) Reset() { *m = Ip4AndPort{} } -func (m *Ip4AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip4AndPort) ProtoMessage() {} -func (*Ip4AndPort) Descriptor() ([]byte, []int) { +func (m *Addr) Reset() { *m = Addr{} } +func (m *Addr) String() string { return proto.CompactTextString(m) } +func (*Addr) ProtoMessage() {} +func (*Addr) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{2} } -func (m *Ip4AndPort) XXX_Unmarshal(b []byte) error { +func (m *Addr) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *Addr) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Ip4AndPort.Marshal(b, m, deterministic) + return xxx_messageInfo_Addr.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -281,86 +299,138 @@ func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Ip4AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip4AndPort.Merge(m, src) +func (m *Addr) XXX_Merge(src proto.Message) { + xxx_messageInfo_Addr.Merge(m, src) } -func (m *Ip4AndPort) XXX_Size() int { +func (m *Addr) XXX_Size() int { return m.Size() } -func (m *Ip4AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip4AndPort.DiscardUnknown(m) +func (m *Addr) XXX_DiscardUnknown() { + xxx_messageInfo_Addr.DiscardUnknown(m) } -var xxx_messageInfo_Ip4AndPort proto.InternalMessageInfo +var xxx_messageInfo_Addr proto.InternalMessageInfo -func (m *Ip4AndPort) GetIp() uint32 { - if m != nil { - return m.Ip - } - return 0 -} - -func (m *Ip4AndPort) GetPort() uint32 { - if m != nil { - return m.Port - } - return 0 -} - -type Ip6AndPort struct { - Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` - Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` - Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` -} - -func (m *Ip6AndPort) Reset() { *m = Ip6AndPort{} } -func (m *Ip6AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip6AndPort) ProtoMessage() {} -func (*Ip6AndPort) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{3} -} -func (m *Ip6AndPort) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *Ip6AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_Ip6AndPort.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *Ip6AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip6AndPort.Merge(m, src) -} -func (m *Ip6AndPort) XXX_Size() int { - return m.Size() -} -func (m *Ip6AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip6AndPort.DiscardUnknown(m) -} - -var xxx_messageInfo_Ip6AndPort proto.InternalMessageInfo - -func (m *Ip6AndPort) GetHi() uint64 { +func (m *Addr) GetHi() uint64 { if m != nil { return m.Hi } return 0 } -func (m *Ip6AndPort) GetLo() uint64 { +func (m *Addr) GetLo() uint64 { if m != nil { return m.Lo } return 0 } -func (m *Ip6AndPort) GetPort() uint32 { +type V4AddrPort struct { + Addr uint32 `protobuf:"varint,1,opt,name=Addr,proto3" json:"Addr,omitempty"` + Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +} + +func (m *V4AddrPort) Reset() { *m = V4AddrPort{} } +func (m *V4AddrPort) String() string { return proto.CompactTextString(m) } +func (*V4AddrPort) ProtoMessage() {} +func (*V4AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{3} +} +func (m *V4AddrPort) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *V4AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_V4AddrPort.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *V4AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V4AddrPort.Merge(m, src) +} +func (m *V4AddrPort) XXX_Size() int { + return m.Size() +} +func (m *V4AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V4AddrPort.DiscardUnknown(m) +} + +var xxx_messageInfo_V4AddrPort proto.InternalMessageInfo + +func (m *V4AddrPort) GetAddr() uint32 { + if m != nil { + return m.Addr + } + return 0 +} + +func (m *V4AddrPort) GetPort() uint32 { + if m != nil { + return m.Port + } + return 0 +} + +type V6AddrPort struct { + Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` + Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` + Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` +} + +func (m *V6AddrPort) Reset() { *m = V6AddrPort{} } +func (m *V6AddrPort) String() string { return proto.CompactTextString(m) } +func (*V6AddrPort) ProtoMessage() {} +func (*V6AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{4} +} +func (m *V6AddrPort) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *V6AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_V6AddrPort.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *V6AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V6AddrPort.Merge(m, src) +} +func (m *V6AddrPort) XXX_Size() int { + return m.Size() +} +func (m *V6AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V6AddrPort.DiscardUnknown(m) +} + +var xxx_messageInfo_V6AddrPort proto.InternalMessageInfo + +func (m *V6AddrPort) GetHi() uint64 { + if m != nil { + return m.Hi + } + return 0 +} + +func (m *V6AddrPort) GetLo() uint64 { + if m != nil { + return m.Lo + } + return 0 +} + +func (m *V6AddrPort) GetPort() uint32 { if m != nil { return m.Port } @@ -376,7 +446,7 @@ func (m *NebulaPing) Reset() { *m = NebulaPing{} } func (m *NebulaPing) String() string { return proto.CompactTextString(m) } func (*NebulaPing) ProtoMessage() {} func (*NebulaPing) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4} + return fileDescriptor_2d65afa7693df5ef, []int{5} } func (m *NebulaPing) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -428,7 +498,7 @@ func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } func (*NebulaHandshake) ProtoMessage() {} func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{5} + return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -477,13 +547,14 @@ type NebulaHandshakeDetails struct { ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` + CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` } func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } func (*NebulaHandshakeDetails) ProtoMessage() {} func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} + return fileDescriptor_2d65afa7693df5ef, []int{7} } func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -547,19 +618,28 @@ func (m *NebulaHandshakeDetails) GetTime() uint64 { return 0 } +func (m *NebulaHandshakeDetails) GetCertVersion() uint32 { + if m != nil { + return m.CertVersion + } + return 0 +} + type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` ResponderRelayIndex uint32 `protobuf:"varint,3,opt,name=ResponderRelayIndex,proto3" json:"ResponderRelayIndex,omitempty"` - RelayToIp uint32 `protobuf:"varint,4,opt,name=RelayToIp,proto3" json:"RelayToIp,omitempty"` - RelayFromIp uint32 `protobuf:"varint,5,opt,name=RelayFromIp,proto3" json:"RelayFromIp,omitempty"` + OldRelayToAddr uint32 `protobuf:"varint,4,opt,name=OldRelayToAddr,proto3" json:"OldRelayToAddr,omitempty"` // Deprecated: Do not use. + OldRelayFromAddr uint32 `protobuf:"varint,5,opt,name=OldRelayFromAddr,proto3" json:"OldRelayFromAddr,omitempty"` // Deprecated: Do not use. + RelayToAddr *Addr `protobuf:"bytes,6,opt,name=RelayToAddr,proto3" json:"RelayToAddr,omitempty"` + RelayFromAddr *Addr `protobuf:"bytes,7,opt,name=RelayFromAddr,proto3" json:"RelayFromAddr,omitempty"` } func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} + return fileDescriptor_2d65afa7693df5ef, []int{8} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -609,28 +689,45 @@ func (m *NebulaControl) GetResponderRelayIndex() uint32 { return 0 } -func (m *NebulaControl) GetRelayToIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayToAddr() uint32 { if m != nil { - return m.RelayToIp + return m.OldRelayToAddr } return 0 } -func (m *NebulaControl) GetRelayFromIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayFromAddr() uint32 { if m != nil { - return m.RelayFromIp + return m.OldRelayFromAddr } return 0 } +func (m *NebulaControl) GetRelayToAddr() *Addr { + if m != nil { + return m.RelayToAddr + } + return nil +} + +func (m *NebulaControl) GetRelayFromAddr() *Addr { + if m != nil { + return m.RelayFromAddr + } + return nil +} + func init() { proto.RegisterEnum("nebula.NebulaMeta_MessageType", NebulaMeta_MessageType_name, NebulaMeta_MessageType_value) proto.RegisterEnum("nebula.NebulaPing_MessageType", NebulaPing_MessageType_name, NebulaPing_MessageType_value) proto.RegisterEnum("nebula.NebulaControl_MessageType", NebulaControl_MessageType_name, NebulaControl_MessageType_value) proto.RegisterType((*NebulaMeta)(nil), "nebula.NebulaMeta") proto.RegisterType((*NebulaMetaDetails)(nil), "nebula.NebulaMetaDetails") - proto.RegisterType((*Ip4AndPort)(nil), "nebula.Ip4AndPort") - proto.RegisterType((*Ip6AndPort)(nil), "nebula.Ip6AndPort") + proto.RegisterType((*Addr)(nil), "nebula.Addr") + proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") + proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") @@ -640,52 +737,57 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 707 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0x4d, 0x6f, 0xda, 0x4a, - 0x14, 0xc5, 0xc6, 0x7c, 0x5d, 0x02, 0xf1, 0xbb, 0x79, 0x8f, 0x07, 0x4f, 0xaf, 0x16, 0xf5, 0xa2, - 0x62, 0x45, 0x22, 0x92, 0x46, 0x5d, 0x36, 0xa5, 0xaa, 0x20, 0x4a, 0x22, 0x3a, 0x4a, 0x5b, 0xa9, - 0x9b, 0x6a, 0x62, 0xa6, 0xc1, 0x02, 0x3c, 0x8e, 0x3d, 0x54, 0xe1, 0x5f, 0xf4, 0xc7, 0xe4, 0x47, - 0x74, 0xd7, 0x2c, 0xbb, 0xac, 0x92, 0x65, 0x97, 0xfd, 0x03, 0xd5, 0x8c, 0xc1, 0x36, 0x84, 0x76, - 0x37, 0xe7, 0xde, 0x73, 0x66, 0xce, 0x9c, 0xb9, 0x36, 0x6c, 0x79, 0xec, 0x62, 0x36, 0xa1, 0x6d, - 0x3f, 0xe0, 0x82, 0x63, 0x3e, 0x42, 0xf6, 0x0f, 0x1d, 0xe0, 0x4c, 0x2d, 0x4f, 0x99, 0xa0, 0xd8, - 0x01, 0xe3, 0x7c, 0xee, 0xb3, 0xba, 0xd6, 0xd4, 0x5a, 0xd5, 0x8e, 0xd5, 0x5e, 0x68, 0x12, 0x46, - 0xfb, 0x94, 0x85, 0x21, 0xbd, 0x64, 0x92, 0x45, 0x14, 0x17, 0xf7, 0xa1, 0xf0, 0x92, 0x09, 0xea, - 0x4e, 0xc2, 0xba, 0xde, 0xd4, 0x5a, 0xe5, 0x4e, 0xe3, 0xa1, 0x6c, 0x41, 0x20, 0x4b, 0xa6, 0xfd, - 0x53, 0x83, 0x72, 0x6a, 0x2b, 0x2c, 0x82, 0x71, 0xc6, 0x3d, 0x66, 0x66, 0xb0, 0x02, 0xa5, 0x1e, - 0x0f, 0xc5, 0xeb, 0x19, 0x0b, 0xe6, 0xa6, 0x86, 0x08, 0xd5, 0x18, 0x12, 0xe6, 0x4f, 0xe6, 0xa6, - 0x8e, 0xff, 0x41, 0x4d, 0xd6, 0xde, 0xf8, 0x43, 0x2a, 0xd8, 0x19, 0x17, 0xee, 0x47, 0xd7, 0xa1, - 0xc2, 0xe5, 0x9e, 0x99, 0xc5, 0x06, 0xfc, 0x23, 0x7b, 0xa7, 0xfc, 0x13, 0x1b, 0xae, 0xb4, 0x8c, - 0x65, 0x6b, 0x30, 0xf3, 0x9c, 0xd1, 0x4a, 0x2b, 0x87, 0x55, 0x00, 0xd9, 0x7a, 0x37, 0xe2, 0x74, - 0xea, 0x9a, 0x79, 0xdc, 0x81, 0xed, 0x04, 0x47, 0xc7, 0x16, 0xa4, 0xb3, 0x01, 0x15, 0xa3, 0xee, - 0x88, 0x39, 0x63, 0xb3, 0x28, 0x9d, 0xc5, 0x30, 0xa2, 0x94, 0xf0, 0x11, 0x34, 0x36, 0x3b, 0x3b, - 0x72, 0xc6, 0x26, 0xd8, 0x5f, 0x35, 0xf8, 0xeb, 0x41, 0x28, 0xf8, 0x37, 0xe4, 0xde, 0xfa, 0x5e, - 0xdf, 0x57, 0xa9, 0x57, 0x48, 0x04, 0xf0, 0x00, 0xca, 0x7d, 0xff, 0xe0, 0xc8, 0x1b, 0x0e, 0x78, - 0x20, 0x64, 0xb4, 0xd9, 0x56, 0xb9, 0x83, 0xcb, 0x68, 0x93, 0x16, 0x49, 0xd3, 0x22, 0xd5, 0x61, - 0xac, 0x32, 0xd6, 0x55, 0x87, 0x29, 0x55, 0x4c, 0x43, 0x0b, 0x80, 0xb0, 0x09, 0x9d, 0x47, 0x36, - 0x72, 0xcd, 0x6c, 0xab, 0x42, 0x52, 0x15, 0xac, 0x43, 0xc1, 0xe1, 0x33, 0x4f, 0xb0, 0xa0, 0x9e, - 0x55, 0x1e, 0x97, 0xd0, 0xde, 0x03, 0x48, 0x8e, 0xc7, 0x2a, 0xe8, 0xf1, 0x35, 0xf4, 0xbe, 0x8f, - 0x08, 0x86, 0xac, 0xab, 0xb9, 0xa8, 0x10, 0xb5, 0xb6, 0x9f, 0x4b, 0xc5, 0x61, 0x4a, 0xd1, 0x73, - 0x95, 0xc2, 0x20, 0x7a, 0xcf, 0x95, 0xf8, 0x84, 0x2b, 0xbe, 0x41, 0xf4, 0x13, 0x1e, 0xef, 0x90, - 0x4d, 0xed, 0x70, 0xbd, 0x1c, 0xd9, 0x81, 0xeb, 0x5d, 0xfe, 0x79, 0x64, 0x25, 0x63, 0xc3, 0xc8, - 0x22, 0x18, 0xe7, 0xee, 0x94, 0x2d, 0xce, 0x51, 0x6b, 0xdb, 0x7e, 0x30, 0x90, 0x52, 0x6c, 0x66, - 0xb0, 0x04, 0xb9, 0xe8, 0x79, 0x35, 0xfb, 0x03, 0x6c, 0x47, 0xfb, 0xf6, 0xa8, 0x37, 0x0c, 0x47, - 0x74, 0xcc, 0xf0, 0x59, 0x32, 0xfd, 0x9a, 0x9a, 0xfe, 0x35, 0x07, 0x31, 0x73, 0xfd, 0x13, 0x90, - 0x26, 0x7a, 0x53, 0xea, 0x28, 0x13, 0x5b, 0x44, 0xad, 0xed, 0x1b, 0x0d, 0x6a, 0x9b, 0x75, 0x92, - 0xde, 0x65, 0x81, 0x50, 0xa7, 0x6c, 0x11, 0xb5, 0xc6, 0x27, 0x50, 0xed, 0x7b, 0xae, 0x70, 0xa9, - 0xe0, 0x41, 0xdf, 0x1b, 0xb2, 0xeb, 0x45, 0xd2, 0x6b, 0x55, 0xc9, 0x23, 0x2c, 0xf4, 0xb9, 0x37, - 0x64, 0x0b, 0x5e, 0x94, 0xe7, 0x5a, 0x15, 0x6b, 0x90, 0xef, 0x72, 0x3e, 0x76, 0x59, 0xdd, 0x50, - 0xc9, 0x2c, 0x50, 0x9c, 0x57, 0x2e, 0xc9, 0xeb, 0xd8, 0x28, 0xe6, 0xcd, 0xc2, 0xb1, 0x51, 0x2c, - 0x98, 0x45, 0xfb, 0x46, 0x87, 0x4a, 0x64, 0xbb, 0xcb, 0x3d, 0x11, 0xf0, 0x09, 0x3e, 0x5d, 0x79, - 0x95, 0xc7, 0xab, 0x99, 0x2c, 0x48, 0x1b, 0x1e, 0x66, 0x0f, 0x76, 0x62, 0xeb, 0x6a, 0xfe, 0xd2, - 0xb7, 0xda, 0xd4, 0x92, 0x8a, 0xf8, 0x12, 0x29, 0x45, 0x74, 0xbf, 0x4d, 0x2d, 0xfc, 0x1f, 0x4a, - 0x0a, 0x9d, 0xf3, 0xbe, 0xaf, 0xee, 0x59, 0x21, 0x49, 0x01, 0x9b, 0x50, 0x56, 0xe0, 0x55, 0xc0, - 0xa7, 0xea, 0x5b, 0x90, 0xfd, 0x74, 0xc9, 0xee, 0xfd, 0xee, 0xcf, 0x55, 0x03, 0xec, 0x06, 0x8c, - 0x0a, 0xa6, 0xd8, 0x84, 0x5d, 0xcd, 0x58, 0x28, 0x4c, 0x0d, 0xff, 0x85, 0x9d, 0x95, 0xba, 0xb4, - 0x14, 0x32, 0x53, 0x7f, 0xb1, 0xff, 0xe5, 0xce, 0xd2, 0x6e, 0xef, 0x2c, 0xed, 0xfb, 0x9d, 0xa5, - 0x7d, 0xbe, 0xb7, 0x32, 0xb7, 0xf7, 0x56, 0xe6, 0xdb, 0xbd, 0x95, 0x79, 0xdf, 0xb8, 0x74, 0xc5, - 0x68, 0x76, 0xd1, 0x76, 0xf8, 0x74, 0x37, 0x9c, 0x50, 0x67, 0x3c, 0xba, 0xda, 0x8d, 0x22, 0xbc, - 0xc8, 0xab, 0x1f, 0xf8, 0xfe, 0xaf, 0x00, 0x00, 0x00, 0xff, 0xff, 0x17, 0x56, 0x28, 0x74, 0xd0, - 0x05, 0x00, 0x00, + // 785 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, + 0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, + 0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, + 0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, + 0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, + 0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed, + 0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, + 0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, + 0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, + 0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, + 0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, + 0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, + 0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, + 0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, + 0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, + 0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, + 0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, + 0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, + 0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, + 0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, + 0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, + 0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, + 0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, + 0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, + 0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, + 0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, + 0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8, + 0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, + 0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, + 0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, + 0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, + 0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, + 0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, + 0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, + 0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, + 0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, + 0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, + 0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, + 0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, + 0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, + 0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, + 0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, + 0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf, + 0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f, + 0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8, + 0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65, + 0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9, + 0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94, + 0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00, + 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -748,28 +850,54 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if len(m.RelayVpnIp) > 0 { - dAtA3 := make([]byte, len(m.RelayVpnIp)*10) - var j2 int - for _, num := range m.RelayVpnIp { - for num >= 1<<7 { - dAtA3[j2] = uint8(uint64(num)&0x7f | 0x80) - num >>= 7 - j2++ + if len(m.RelayVpnAddrs) > 0 { + for iNdEx := len(m.RelayVpnAddrs) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.RelayVpnAddrs[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) } - dAtA3[j2] = uint8(num) - j2++ + i-- + dAtA[i] = 0x3a } - i -= j2 - copy(dAtA[i:], dAtA3[:j2]) - i = encodeVarintNebula(dAtA, i, uint64(j2)) + } + if m.VpnAddr != nil { + { + size, err := m.VpnAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if len(m.OldRelayVpnAddrs) > 0 { + dAtA4 := make([]byte, len(m.OldRelayVpnAddrs)*10) + var j3 int + for _, num := range m.OldRelayVpnAddrs { + for num >= 1<<7 { + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j3++ + } + dAtA4[j3] = uint8(num) + j3++ + } + i -= j3 + copy(dAtA[i:], dAtA4[:j3]) + i = encodeVarintNebula(dAtA, i, uint64(j3)) i-- dAtA[i] = 0x2a } - if len(m.Ip6AndPorts) > 0 { - for iNdEx := len(m.Ip6AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V6AddrPorts) > 0 { + for iNdEx := len(m.V6AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip6AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V6AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -785,10 +913,10 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x18 } - if len(m.Ip4AndPorts) > 0 { - for iNdEx := len(m.Ip4AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V4AddrPorts) > 0 { + for iNdEx := len(m.V4AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip4AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V4AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -799,15 +927,15 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { dAtA[i] = 0x12 } } - if m.VpnIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldVpnAddr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { +func (m *Addr) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -817,12 +945,45 @@ func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip4AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *Addr) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *Addr) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Lo != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Lo)) + i-- + dAtA[i] = 0x10 + } + if m.Hi != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Hi)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *V4AddrPort) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *V4AddrPort) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *V4AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -832,15 +993,15 @@ func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x10 } - if m.Ip != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Ip)) + if m.Addr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Addr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { +func (m *V6AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -850,12 +1011,12 @@ func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip6AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip6AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -973,6 +1134,11 @@ func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) _ = i var l int _ = l + if m.CertVersion != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion)) + i-- + dAtA[i] = 0x40 + } if m.Time != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Time)) i-- @@ -1023,13 +1189,37 @@ func (m *NebulaControl) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.RelayFromIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayFromIp)) + if m.RelayFromAddr != nil { + { + size, err := m.RelayFromAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x3a + } + if m.RelayToAddr != nil { + { + size, err := m.RelayToAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if m.OldRelayFromAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayFromAddr)) i-- dAtA[i] = 0x28 } - if m.RelayToIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayToAddr)) i-- dAtA[i] = 0x20 } @@ -1084,11 +1274,11 @@ func (m *NebulaMetaDetails) Size() (n int) { } var l int _ = l - if m.VpnIp != 0 { - n += 1 + sovNebula(uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + n += 1 + sovNebula(uint64(m.OldVpnAddr)) } - if len(m.Ip4AndPorts) > 0 { - for _, e := range m.Ip4AndPorts { + if len(m.V4AddrPorts) > 0 { + for _, e := range m.V4AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } @@ -1096,30 +1286,55 @@ func (m *NebulaMetaDetails) Size() (n int) { if m.Counter != 0 { n += 1 + sovNebula(uint64(m.Counter)) } - if len(m.Ip6AndPorts) > 0 { - for _, e := range m.Ip6AndPorts { + if len(m.V6AddrPorts) > 0 { + for _, e := range m.V6AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } } - if len(m.RelayVpnIp) > 0 { + if len(m.OldRelayVpnAddrs) > 0 { l = 0 - for _, e := range m.RelayVpnIp { + for _, e := range m.OldRelayVpnAddrs { l += sovNebula(uint64(e)) } n += 1 + sovNebula(uint64(l)) + l } + if m.VpnAddr != nil { + l = m.VpnAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if len(m.RelayVpnAddrs) > 0 { + for _, e := range m.RelayVpnAddrs { + l = e.Size() + n += 1 + l + sovNebula(uint64(l)) + } + } return n } -func (m *Ip4AndPort) Size() (n int) { +func (m *Addr) Size() (n int) { if m == nil { return 0 } var l int _ = l - if m.Ip != 0 { - n += 1 + sovNebula(uint64(m.Ip)) + if m.Hi != 0 { + n += 1 + sovNebula(uint64(m.Hi)) + } + if m.Lo != 0 { + n += 1 + sovNebula(uint64(m.Lo)) + } + return n +} + +func (m *V4AddrPort) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Addr != 0 { + n += 1 + sovNebula(uint64(m.Addr)) } if m.Port != 0 { n += 1 + sovNebula(uint64(m.Port)) @@ -1127,7 +1342,7 @@ func (m *Ip4AndPort) Size() (n int) { return n } -func (m *Ip6AndPort) Size() (n int) { +func (m *V6AddrPort) Size() (n int) { if m == nil { return 0 } @@ -1199,6 +1414,9 @@ func (m *NebulaHandshakeDetails) Size() (n int) { if m.Time != 0 { n += 1 + sovNebula(uint64(m.Time)) } + if m.CertVersion != 0 { + n += 1 + sovNebula(uint64(m.CertVersion)) + } return n } @@ -1217,11 +1435,19 @@ func (m *NebulaControl) Size() (n int) { if m.ResponderRelayIndex != 0 { n += 1 + sovNebula(uint64(m.ResponderRelayIndex)) } - if m.RelayToIp != 0 { - n += 1 + sovNebula(uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayToAddr)) } - if m.RelayFromIp != 0 { - n += 1 + sovNebula(uint64(m.RelayFromIp)) + if m.OldRelayFromAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayFromAddr)) + } + if m.RelayToAddr != nil { + l = m.RelayToAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if m.RelayFromAddr != nil { + l = m.RelayFromAddr.Size() + n += 1 + l + sovNebula(uint64(l)) } return n } @@ -1368,9 +1594,9 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field VpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldVpnAddr", wireType) } - m.VpnIp = 0 + m.OldVpnAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1380,14 +1606,14 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.VpnIp |= uint32(b&0x7F) << shift + m.OldVpnAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip4AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V4AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1414,8 +1640,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip4AndPorts = append(m.Ip4AndPorts, &Ip4AndPort{}) - if err := m.Ip4AndPorts[len(m.Ip4AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V4AddrPorts = append(m.V4AddrPorts, &V4AddrPort{}) + if err := m.V4AddrPorts[len(m.V4AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1440,7 +1666,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } case 4: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip6AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V6AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1467,8 +1693,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip6AndPorts = append(m.Ip6AndPorts, &Ip6AndPort{}) - if err := m.Ip6AndPorts[len(m.Ip6AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V6AddrPorts = append(m.V6AddrPorts, &V6AddrPort{}) + if err := m.V6AddrPorts[len(m.V6AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1489,7 +1715,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } else if wireType == 2 { var packedLen int for shift := uint(0); ; shift += 7 { @@ -1524,8 +1750,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } } elementCount = count - if elementCount != 0 && len(m.RelayVpnIp) == 0 { - m.RelayVpnIp = make([]uint32, 0, elementCount) + if elementCount != 0 && len(m.OldRelayVpnAddrs) == 0 { + m.OldRelayVpnAddrs = make([]uint32, 0, elementCount) } for iNdEx < postIndex { var v uint32 @@ -1543,11 +1769,81 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } } else { - return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayVpnAddrs", wireType) } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field VpnAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.VpnAddr == nil { + m.VpnAddr = &Addr{} + } + if err := m.VpnAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnAddrs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.RelayVpnAddrs = append(m.RelayVpnAddrs, &Addr{}) + if err := m.RelayVpnAddrs[len(m.RelayVpnAddrs)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -1569,7 +1865,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { +func (m *Addr) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1592,17 +1888,17 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip4AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: Addr: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip4AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: Addr: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Hi", wireType) } - m.Ip = 0 + m.Hi = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1612,7 +1908,95 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.Ip |= uint32(b&0x7F) << shift + m.Hi |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Lo", wireType) + } + m.Lo = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Lo |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipNebula(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthNebula + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *V4AddrPort) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: V4AddrPort: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: V4AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Addr", wireType) + } + m.Addr = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Addr |= uint32(b&0x7F) << shift if b < 0x80 { break } @@ -1657,7 +2041,7 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { +func (m *V6AddrPort) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1680,10 +2064,10 @@ func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip6AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: V6AddrPort: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip6AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: V6AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -2111,6 +2495,25 @@ func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { break } } + case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType) + } + m.CertVersion = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.CertVersion |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -2220,9 +2623,9 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } case 4: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayToIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayToAddr", wireType) } - m.RelayToIp = 0 + m.OldRelayToAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2232,16 +2635,16 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayToIp |= uint32(b&0x7F) << shift + m.OldRelayToAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 5: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayFromIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayFromAddr", wireType) } - m.RelayFromIp = 0 + m.OldRelayFromAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2251,11 +2654,83 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayFromIp |= uint32(b&0x7F) << shift + m.OldRelayFromAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayToAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayToAddr == nil { + m.RelayToAddr = &Addr{} + } + if err := m.RelayToAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayFromAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayFromAddr == nil { + m.RelayFromAddr = &Addr{} + } + if err := m.RelayFromAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) diff --git a/nebula.proto b/nebula.proto index 88e33b7..ea10233 100644 --- a/nebula.proto +++ b/nebula.proto @@ -23,19 +23,28 @@ message NebulaMeta { } message NebulaMetaDetails { - uint32 VpnIp = 1; - repeated Ip4AndPort Ip4AndPorts = 2; - repeated Ip6AndPort Ip6AndPorts = 4; - repeated uint32 RelayVpnIp = 5; + uint32 OldVpnAddr = 1 [deprecated = true]; + Addr VpnAddr = 6; + + repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true]; + repeated Addr RelayVpnAddrs = 7; + + repeated V4AddrPort V4AddrPorts = 2; + repeated V6AddrPort V6AddrPorts = 4; uint32 counter = 3; } -message Ip4AndPort { - uint32 Ip = 1; +message Addr { + uint64 Hi = 1; + uint64 Lo = 2; +} + +message V4AddrPort { + uint32 Addr = 1; uint32 Port = 2; } -message Ip6AndPort { +message V6AddrPort { uint64 Hi = 1; uint64 Lo = 2; uint32 Port = 3; @@ -62,6 +71,7 @@ message NebulaHandshakeDetails { uint32 ResponderIndex = 3; uint64 Cookie = 4; uint64 Time = 5; + uint32 CertVersion = 8; // reserved for WIP multiport reserved 6, 7; } @@ -76,6 +86,10 @@ message NebulaControl { uint32 InitiatorRelayIndex = 2; uint32 ResponderRelayIndex = 3; - uint32 RelayToIp = 4; - uint32 RelayFromIp = 5; + + uint32 OldRelayToAddr = 4 [deprecated = true]; + uint32 OldRelayFromAddr = 5 [deprecated = true]; + + Addr RelayToAddr = 6; + Addr RelayFromAddr = 7; } diff --git a/noise.go b/noise.go index 91ad2c0..57990a7 100644 --- a/noise.go +++ b/noise.go @@ -28,11 +28,11 @@ func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { // EncryptDanger encrypts and authenticates a given payload. // // out is a destination slice to hold the output of the EncryptDanger operation. -// - ad is additional data, which will be authenticated and appended to out, but not encrypted. -// - plaintext is encrypted, authenticated and appended to out. -// - n is a nonce value which must never be re-used with this key. -// - nb is a buffer used for temporary storage in the implementation of this call, which should -// be re-used by callers to minimize garbage collection. +// - ad is additional data, which will be authenticated and appended to out, but not encrypted. +// - plaintext is encrypted, authenticated and appended to out. +// - n is a nonce value which must never be re-used with this key. +// - nb is a buffer used for temporary storage in the implementation of this call, which should +// be re-used by callers to minimize garbage collection. func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { if s != nil { // TODO: Is this okay now that we have made messageCounter atomic? diff --git a/noiseutil/pkcs11.go b/noiseutil/pkcs11.go new file mode 100644 index 0000000..d1c7ba9 --- /dev/null +++ b/noiseutil/pkcs11.go @@ -0,0 +1,50 @@ +package noiseutil + +import ( + "crypto/ecdh" + "fmt" + "strings" + + "github.com/slackhq/nebula/pkclient" + + "github.com/flynn/noise" +) + +// DHP256PKCS11 is the NIST P-256 ECDH function +var DHP256PKCS11 noise.DHFunc = newNISTP11Curve("P256", ecdh.P256(), 32) + +type nistP11Curve struct { + nistCurve +} + +func newNISTP11Curve(name string, curve ecdh.Curve, byteLen int) nistP11Curve { + return nistP11Curve{ + newNISTCurve(name, curve, byteLen), + } +} + +func (c nistP11Curve) DH(privkey, pubkey []byte) ([]byte, error) { + //for this function "privkey" is actually a pkcs11 URI + pkStr := string(privkey) + + //to set up a handshake, we need to also do non-pkcs11-DH. Handle that here. + if !strings.HasPrefix(pkStr, "pkcs11:") { + return DHP256.DH(privkey, pubkey) + } + ecdhPubKey, err := c.curve.NewPublicKey(pubkey) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err) + } + + //this is not the most performant way to do this (a long-lived client would be better) + //but, it works, and helps avoid problems with stale sessions and HSMs used by multiple users. + client, err := pkclient.FromUrl(pkStr) + if err != nil { + return nil, err + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(client) + + return client.DeriveNoise(ecdhPubKey.Bytes()) +} diff --git a/outside.go b/outside.go index 818e2ae..1e9cde1 100644 --- a/outside.go +++ b/outside.go @@ -3,61 +3,40 @@ package nebula import ( "encoding/binary" "errors" - "fmt" + "net/netip" "time" - "github.com/flynn/noise" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv6" + "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" - "google.golang.org/protobuf/proto" ) const ( minFwPacketLen = 4 ) -func readOutsidePackets(f *Interface) udp.EncReader { - return func( - addr *udp.Addr, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh udp.LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, - ) { - f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) - } -} - -func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { - // TODO: best if we return this and let caller log - // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) + f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) } return } //l.Error("in packet ", header, packet[HeaderLen:]) - if addr != nil { - if ip4 := addr.IP.To4(); ip4 != nil { - if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet") - } - return + if ip.IsValid() { + _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) + if found { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } + return } } @@ -77,7 +56,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt switch h.Type { case header.Message: // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } @@ -101,7 +80,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo.localIndexId) f.connectionManager.RelayUsed(h.RemoteIndex) @@ -110,7 +89,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") return } @@ -118,13 +97,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") + hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") return } @@ -140,7 +119,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") return } } @@ -148,46 +127,40 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.LightHouse: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") - - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(net.Addr(addr), header.RemoteIndex) return } - lhf(addr, hostinfo.vpnIp, d) + lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic case header.Test: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt test packet") - - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(net.Addr(addr), header.RemoteIndex) return } if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) } @@ -198,54 +171,48 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(addr, via, packet, h) + f.handshakeManager.HandleIncoming(ip, via, packet, h) return case header.RecvError: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(addr, h) + f.handleRecvError(ip, h) return case header.CloseTunnel: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } - hostinfo.logger(f.l).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", ip). Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) return case header.Control: - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt Control packet") return } - m := &NebulaControl{} - err = m.Unmarshal(d) - if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message") - break - } - f.relayManager.HandleControlMsg(hostinfo, m, f) + f.relayManager.HandleControlMsg(hostinfo, d, f) default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) return } - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.connectionManager.In(hostinfo.localIndexId) } @@ -254,8 +221,8 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) if final { - // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage - f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) + // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage + f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs) } } @@ -264,34 +231,35 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) { - if addr != nil && !hostinfo.remote.Equals(addr) { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { - hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { + if udpAddr.IsValid() && hostinfo.remote != udpAddr { + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + + if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(addr) + hostinfo.SetRemote(udpAddr) } } -func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool { +func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { // If connectionstate exists and the replay protector allows, process packet // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. if ci == nil || !ci.window.Check(f.l, h.MessageCounter) { - if addr != nil { + if addr.IsValid() { f.maybeSendRecvError(addr, h.RemoteIndex) return false } else { @@ -302,24 +270,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *head return true } +var ( + ErrPacketTooShort = errors.New("packet is too short") + ErrUnknownIPVersion = errors.New("packet is an unknown ip version") + ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length") + ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short") + ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short") + ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet") +) + // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { - // Do we at least have an ipv4 header worth of data? - if len(data) < ipv4.HeaderLen { - return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) + if len(data) < 1 { + return ErrPacketTooShort } - // Is it an ipv4 packet? - if int((data[0]>>4)&0x0f) != 4 { - return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) + version := int((data[0] >> 4) & 0x0f) + switch version { + case ipv4.Version: + return parseV4(data, incoming, fp) + case ipv6.Version: + return parseV6(data, incoming, fp) + } + return ErrUnknownIPVersion +} + +func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { + dataLen := len(data) + if dataLen < ipv6.HeaderLen { + return ErrIPv6PacketTooShort + } + + if incoming { + fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40]) + } else { + fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40]) + } + + protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header + offset := ipv6.HeaderLen // Start at the end of the ipv6 header + next := 0 + for { + if dataLen < offset { + break + } + + proto := layers.IPProtocol(data[protoAt]) + //fmt.Println(proto, protoAt) + switch proto { + case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: + fp.Protocol = uint8(proto) + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = false + return nil + + case layers.IPProtocolTCP, layers.IPProtocolUDP: + if dataLen < offset+4 { + return ErrIPv6PacketTooShort + } + + fp.Protocol = uint8(proto) + if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } + + fp.Fragment = false + return nil + + case layers.IPProtocolIPv6Fragment: + // Fragment header is 8 bytes, need at least offset+4 to read the offset field + if dataLen < offset+8 { + return ErrIPv6PacketTooShort + } + + // Check if this is the first fragment + fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits + if fragmentOffset != 0 { + // Non-first fragment, use what we have now and stop processing + fp.Protocol = data[offset] + fp.Fragment = true + fp.RemotePort = 0 + fp.LocalPort = 0 + return nil + } + + // The next loop should be the transport layer since we are the first fragment + next = 8 // Fragment headers are always 8 bytes + + case layers.IPProtocolAH: + // Auth headers, used by IPSec, have a different meaning for header length + if dataLen < offset+1 { + break + } + + next = int(data[offset+1]+2) << 2 + + default: + // Normal ipv6 header length processing + if dataLen < offset+1 { + break + } + + next = int(data[offset+1]+1) << 3 + } + + if next <= 0 { + // Safety check, each ipv6 header has to be at least 8 bytes + next = 8 + } + + protoAt = offset + offset = offset + next + } + + return ErrIPv6CouldNotFindPayload +} + +func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { + // Do we at least have an ipv4 header worth of data? + if len(data) < ipv4.HeaderLen { + return ErrIPv4PacketTooShort } // Adjust our start position based on the advertised ip header length ihl := int(data[0]&0x0f) << 2 - // Well formed ip header length? + // Well-formed ip header length? if ihl < ipv4.HeaderLen { - return fmt.Errorf("packet had an invalid header length: %v", ihl) + return ErrIPv4InvalidHeaderLength } // Check if this is the second or further fragment of a fragmented packet. @@ -335,13 +420,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { minLen += minFwPacketLen } if len(data) < minLen { - return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) + return ErrIPv4InvalidHeaderLength } // Firewall packets are locally oriented if incoming { - fp.RemoteIP = iputil.Ip2VpnIp(data[12:16]) - fp.LocalIP = iputil.Ip2VpnIp(data[16:20]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -350,8 +435,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - fp.LocalIP = iputil.Ip2VpnIp(data[12:16]) - fp.RemoteIP = iputil.Ip2VpnIp(data[16:20]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -386,8 +471,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(hostinfo.remote, header.RemoteIndex) return false } @@ -425,18 +508,17 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return true } -func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) { - if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) { +func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { + if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) { f.sendRecvError(endpoint, index) } } -func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { +func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) - //TODO: this should be a signed message so we can trust that we should drop the index b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) - f.outside.WriteTo(b, endpoint) + _ = f.outside.WriteTo(b, endpoint) if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", index). WithField("udpAddr", endpoint). @@ -444,7 +526,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { } } -func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { +func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). @@ -461,7 +543,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { return } - if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { + if hostinfo.remote.IsValid() && hostinfo.remote != addr { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return } @@ -470,65 +552,3 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { // We also delete it from pending hostmap to allow for fast reconnect. f.handshakeManager.DeleteHostInfo(hostinfo) } - -/* -func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) { - if ci.eKey != nil { - //TODO: log error? - return - } - - msg, err := proto.Marshal(meta) - if err != nil { - l.Debugln("failed to encode header") - } - - c := ci.messageCounter - b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c) - ci.messageCounter++ - - msg := ci.eKey.EncryptDanger(b, nil, msg, c) - //msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c) - f.outside.WriteTo(msg, endpoint) -} -*/ - -func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) { - pk := h.PeerStatic() - - if pk == nil { - return nil, errors.New("no peer static key was present") - } - - if rawCertBytes == nil { - return nil, errors.New("provided payload was empty") - } - - r := &cert.RawNebulaCertificate{} - err := proto.Unmarshal(rawCertBytes, r) - if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %s", err) - } - - // If the Details are nil, just exit to avoid crashing - if r.Details == nil { - return nil, fmt.Errorf("certificate did not contain any details") - } - - r.Details.PublicKey = pk - recombined, err := proto.Marshal(r) - if err != nil { - return nil, fmt.Errorf("error while recombining certificate: %s", err) - } - - c, _ := cert.UnmarshalNebulaCertificate(recombined) - isValid, err := c.Verify(time.Now(), caPool) - if err != nil { - return c, fmt.Errorf("certificate validation failed: %s", err) - } else if !isValid { - // This case should never happen but here's to defensive programming! - return c, errors.New("certificate validation failed but did not return an error") - } - - return c, nil -} diff --git a/outside_test.go b/outside_test.go index 682107b..c63e57d 100644 --- a/outside_test.go +++ b/outside_test.go @@ -1,21 +1,33 @@ package nebula import ( + "bytes" + "encoding/binary" "net" + "net/netip" "testing" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" ) func Test_newPacket(t *testing.T) { p := &firewall.Packet{} - // length fail - err := newPacket([]byte{0, 1}, true, p) - assert.EqualError(t, err, "packet is less than 20 bytes") + // length fails + err := newPacket([]byte{}, true, p) + require.ErrorIs(t, err, ErrPacketTooShort) + + err = newPacket([]byte{0x40}, true, p) + require.ErrorIs(t, err, ErrIPv4PacketTooShort) + + err = newPacket([]byte{0x60}, true, p) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ @@ -28,16 +40,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - - assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24") + require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet is not ipv4, type: 0") + require.ErrorIs(t, err, ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet had an invalid header length: 8") + require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ @@ -53,12 +64,13 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 3, 0, 4}...) err = newPacket(b, true, p) - assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) - assert.Equal(t, p.RemotePort, uint16(3)) - assert.Equal(t, p.LocalPort, uint16(4)) + require.NoError(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) + assert.Equal(t, uint16(3), p.RemotePort) + assert.Equal(t, uint16(4), p.LocalPort) + assert.False(t, p.Fragment) // account for variable ip header length - outgoing h = ipv4.Header{ @@ -74,10 +86,507 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 5, 0, 6}...) err = newPacket(b, false, p) - assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) - assert.Equal(t, p.RemotePort, uint16(6)) - assert.Equal(t, p.LocalPort, uint16(5)) + require.NoError(t, err) + assert.Equal(t, uint8(2), p.Protocol) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) + assert.Equal(t, uint16(6), p.RemotePort) + assert.Equal(t, uint16(5), p.LocalPort) + assert.False(t, p.Fragment) +} + +func Test_newPacket_v6(t *testing.T) { + p := &firewall.Packet{} + + // invalid ipv6 + ip := layers.IPv6{ + Version: 6, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: false, + } + err := gopacket.SerializeLayers(buffer, opt, &ip) + require.NoError(t, err) + + err = newPacket(buffer.Bytes(), true, p) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A good ICMP packet + ip = layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolICMPv6, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + icmp := layers.ICMPv6{} + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp) + if err != nil { + panic(err) + } + + err = newPacket(buffer.Bytes(), true, p) + require.NoError(t, err) + assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // A good ESP packet + b := buffer.Bytes() + b[6] = byte(layers.IPProtocolESP) + err = newPacket(b, true, p) + require.NoError(t, err) + assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // A good None packet + b = buffer.Bytes() + b[6] = byte(layers.IPProtocolNoNextHeader) + err = newPacket(b, true, p) + require.NoError(t, err) + assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // An unknown protocol packet + b = buffer.Bytes() + b[6] = 255 // 255 is a reserved protocol number + err = newPacket(b, true, p) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A good UDP packet + ip = layers.IPv6{ + Version: 6, + NextHeader: firewall.ProtoUDP, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + err = udp.SetNetworkLayerForChecksum(&ip) + require.NoError(t, err) + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) + if err != nil { + panic(err) + } + b = buffer.Bytes() + + // incoming + err = newPacket(b, true, p) + require.NoError(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // outgoing + err = newPacket(b, false, p) + require.NoError(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Too short UDP packet + err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes + require.ErrorIs(t, err, ErrIPv6PacketTooShort) + + // A good TCP packet + b[6] = byte(layers.IPProtocolTCP) + + // incoming + err = newPacket(b, true, p) + require.NoError(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // outgoing + err = newPacket(b, false, p) + require.NoError(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Too short TCP packet + err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes + require.ErrorIs(t, err, ErrIPv6PacketTooShort) + + // A good UDP packet with an AH header + ip = layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolAH, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + ah := layers.IPSecAH{ + AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef}, + } + ah.NextHeader = layers.IPProtocolUDP + + udpHeader := []byte{ + 0x8d, 0x1b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer.Clear() + err = ip.SerializeTo(buffer, opt) + if err != nil { + panic(err) + } + + b = buffer.Bytes() + ahb := serializeAH(&ah) + b = append(b, ahb...) + b = append(b, udpHeader...) + + err = newPacket(b, true, p) + require.NoError(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // Invalid AH header + b = buffer.Bytes() + err = newPacket(b, true, p) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) +} + +func Test_newPacket_ipv6Fragment(t *testing.T) { + p := &firewall.Packet{} + + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6Fragment, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + // First fragment + fragHeader1 := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0x00, // Fragment Offset high byte (0) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + udpHeader := []byte{ + 0x8d, 0x1b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err := ip.SerializeTo(buffer, opts) + if err != nil { + t.Fatal(err) + } + + firstFrag := buffer.Bytes() + firstFrag = append(firstFrag, fragHeader1...) + firstFrag = append(firstFrag, udpHeader...) + firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Test first fragment incoming + err = newPacket(firstFrag, true, p) + require.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // Test first fragment outgoing + err = newPacket(firstFrag, false, p) + require.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Second fragment + fragHeader2 := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0xb9, // Fragment Offset high byte (185) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + buffer.Clear() + err = ip.SerializeTo(buffer, opts) + if err != nil { + t.Fatal(err) + } + + secondFrag := buffer.Bytes() + secondFrag = append(secondFrag, fragHeader2...) + secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Test second fragment incoming + err = newPacket(secondFrag, true, p) + require.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.True(t, p.Fragment) + + // Test second fragment outgoing + err = newPacket(secondFrag, false, p) + require.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(0), p.LocalPort) + assert.Equal(t, uint16(0), p.RemotePort) + assert.True(t, p.Fragment) + + // Too short of a fragment packet + err = newPacket(secondFrag[:len(secondFrag)-10], false, p) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) +} + +func BenchmarkParseV6(b *testing.B) { + // Regular UDP packet + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := &layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + + buffer := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: true, + } + + err := gopacket.SerializeLayers(buffer, opts, ip, udp) + if err != nil { + b.Fatal(err) + } + normalPacket := buffer.Bytes() + + // First Fragment packet + ipFrag := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6Fragment, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + fragHeader := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0x00, // Fragment Offset high byte (0) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + udpHeader := []byte{ + 0x8d, 0x7b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer.Clear() + err = ipFrag.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + firstFrag := buffer.Bytes() + firstFrag = append(firstFrag, fragHeader...) + firstFrag = append(firstFrag, udpHeader...) + firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Second Fragment packet + fragHeader[2] = 0xb9 // offset 185 + buffer.Clear() + err = ipFrag.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + secondFrag := buffer.Bytes() + secondFrag = append(secondFrag, fragHeader...) + secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + fp := &firewall.Packet{} + + b.Run("Normal", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(normalPacket, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("FirstFragment", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(firstFrag, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("SecondFragment", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(secondFrag, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + // Evil packet + evilPacket := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6HopByHop, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + hopHeader := []byte{ + uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop) + 0x00, // Length + 0x00, 0x00, // Options and padding + 0x00, 0x00, 0x00, 0x00, // More options and padding + } + + lastHopHeader := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Length + 0x00, 0x00, // Options and padding + 0x00, 0x00, 0x00, 0x00, // More options and padding + } + + buffer.Clear() + err = evilPacket.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + evilBytes := buffer.Bytes() + for i := 0; i < 200; i++ { + evilBytes = append(evilBytes, hopHeader...) + } + evilBytes = append(evilBytes, lastHopHeader...) + evilBytes = append(evilBytes, udpHeader...) + evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...) + + b.Run("200 HopByHop headers", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(evilBytes, false, fp); err != nil { + b.Fatal(err) + } + } + }) +} + +// Ensure authentication data is a multiple of 8 bytes by padding if necessary +func padAuthData(authData []byte) []byte { + // Length of Authentication Data must be a multiple of 8 bytes + paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary + if paddingLength > 0 { + authData = append(authData, make([]byte, paddingLength)...) + } + return authData +} + +// Custom function to manually serialize IPSecAH for both IPv4 and IPv6 +func serializeAH(ah *layers.IPSecAH) []byte { + buf := new(bytes.Buffer) + + // Ensure Authentication Data is a multiple of 8 bytes + ah.AuthenticationData = padAuthData(ah.AuthenticationData) + // Calculate Payload Length (in 32-bit words, minus 2) + payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2 + + // Serialize fields + if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil { + panic(err) + } + if len(ah.AuthenticationData) > 0 { + if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil { + panic(err) + } + } + + return buf.Bytes() } diff --git a/overlay/device.go b/overlay/device.go index 3f3f2eb..07146ab 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -2,16 +2,16 @@ package overlay import ( "io" - "net" + "net/netip" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) type Device interface { io.ReadWriteCloser Activate() error - Cidr() *net.IPNet + Networks() []netip.Prefix Name() string - RouteFor(iputil.VpnIp) iputil.VpnIp + RoutesFor(netip.Addr) routing.Gateways NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/route.go b/overlay/route.go index 64c624c..6198958 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -1,34 +1,31 @@ package overlay import ( - "bytes" "fmt" "math" "net" + "net/netip" "runtime" "strconv" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) type Route struct { MTU int Metric int - Cidr *net.IPNet - Via *iputil.VpnIp + Cidr netip.Prefix + Via routing.Gateways Install bool } // Equal determines if a route that could be installed in the system route table is equal to another // Via is ignored since that is only consumed within nebula itself func (r Route) Equal(t Route) bool { - if !r.Cidr.IP.Equal(t.Cidr.IP) { - return false - } - if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) { + if r.Cidr != t.Cidr { return false } if r.Metric != t.Metric { @@ -51,21 +48,23 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { - routeTree := cidr.NewTree4[iputil.VpnIp]() +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { + routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } - if r.Via != nil { - routeTree.AddCIDR(r.Cidr, *r.Via) + gateways := r.Via + if len(gateways) > 0 { + routing.CalculateBucketsForGateways(gateways) + routeTree.Insert(r.Cidr, gateways) } } return routeTree, nil } -func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -73,7 +72,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return []Route{}, nil } - rawRoutes, ok := r.([]interface{}) + rawRoutes, ok := r.([]any) if !ok { return nil, fmt.Errorf("tun.routes is not an array") } @@ -84,7 +83,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { routes := make([]Route, len(rawRoutes)) for i, r := range rawRoutes { - m, ok := r.(map[interface{}]interface{}) + m, ok := r.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1) } @@ -116,17 +115,25 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { MTU: mtu, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !ipWithin(network, r.Cidr) { + found := false + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() { + found = true + break + } + } + + if !found { return nil, fmt.Errorf( - "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", + "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v", i+1, r.Cidr.String(), - network.String(), + networks, ) } @@ -136,7 +143,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -144,7 +151,7 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return []Route{}, nil } - rawRoutes, ok := r.([]interface{}) + rawRoutes, ok := r.([]any) if !ok { return nil, fmt.Errorf("tun.unsafe_routes is not an array") } @@ -155,7 +162,7 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { routes := make([]Route, len(rawRoutes)) for i, r := range rawRoutes { - m, ok := r.(map[interface{}]interface{}) + m, ok := r.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1) } @@ -197,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1) } - via, ok := rVia.(string) - if !ok { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) - } + var gateways routing.Gateways - nVia := net.ParseIP(via) - if nVia == nil { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via) + switch via := rVia.(type) { + case string: + viaIp, err := netip.ParseAddr(via) + if err != nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) + } + + gateways = routing.Gateways{routing.NewGateway(viaIp, 1)} + + case []any: + gateways = make(routing.Gateways, len(via)) + for ig, v := range via { + gatewayMap, ok := v.(map[string]any) + if !ok { + return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1) + } + + rGateway, ok := gatewayMap["gateway"] + if !ok { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1) + } + + parsedGateway, ok := rGateway.(string) + if !ok { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1) + } + + gatewayIp, err := netip.ParseAddr(parsedGateway) + if err != nil { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err) + } + + rGatewayWeight, ok := gatewayMap["weight"] + if !ok { + rGatewayWeight = 1 + } + + gatewayWeight, ok := rGatewayWeight.(int) + if !ok { + _, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32) + if err != nil { + return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1) + } + } + + if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 { + return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight) + } + + gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight) + + } + + default: + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia) } rRoute, ok := m["route"] @@ -212,8 +268,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) } - viaVpnIp := iputil.Ip2VpnIp(nVia) - install := true rInstall, ok := m["install"] if ok { @@ -224,24 +278,26 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { } r := Route{ - Via: &viaVpnIp, + Via: gateways, MTU: mtu, Metric: metric, Install: install, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if ipWithin(network, r.Cidr) { - return nil, fmt.Errorf( - "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", - i+1, - r.Cidr.String(), - network.String(), - ) + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) { + return nil, fmt.Errorf( + "entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v", + i+1, + r.Cidr.String(), + network.String(), + ) + } } routes[i] = r diff --git a/overlay/route_test.go b/overlay/route_test.go index 46fb87c..9a959a5 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -2,92 +2,100 @@ package overlay import ( "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + require.NoError(t, err) // test no routes config - routes, err := parseRoutes(c, n) - assert.Nil(t, err) - assert.Len(t, routes, 0) + routes, err := parseRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) + assert.Empty(t, routes) // not an array - c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": "hi"} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "tun.routes is not an array") + require.EqualError(t, err, "tun.routes is not an array") // no routes - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} - routes, err = parseRoutes(c, n) - assert.Nil(t, err) - assert.Len(t, routes, 0) + c.Settings["tun"] = map[string]any{"routes": []any{}} + routes, err = parseRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) + assert.Empty(t, routes) // weird route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1 in tun.routes is invalid") + require.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") + require.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") + require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not present") + require.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope") + require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} - routes, err = parseRoutes(c, n) + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}} + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") + + // Not in multiple ranges + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}} + routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ - map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, - map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, + c.Settings["tun"] = map[string]any{"routes": []any{ + map[string]any{"mtu": "9000", "route": "10.0.0.0/29"}, + map[string]any{"mtu": "8000", "route": "10.0.0.1/32"}, }} - routes, err = parseRoutes(c, n) - assert.Nil(t, err) + routes, err = parseRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) assert.Len(t, routes, 2) tested := 0 @@ -112,116 +120,141 @@ func Test_parseRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + require.NoError(t, err) // test no routes config - routes, err := parseUnsafeRoutes(c, n) - assert.Nil(t, err) - assert.Len(t, routes, 0) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) + assert.Empty(t, routes) // not an array - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "tun.unsafe_routes is not an array") + require.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} - routes, err = parseUnsafeRoutes(c, n) - assert.Nil(t, err) - assert.Len(t, routes, 0) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) + assert.Empty(t, routes) // weird route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") + require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") // invalid via - for _, invalidValue := range []interface{}{ + for _, invalidValue := range []any{ 127, false, nil, 1.0, []string{"1", "2"}, } { - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) + require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue)) } - // unparsable via - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + // Unparsable list of via + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope") + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") + + // unparsable via + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") + + // unparsable gateway + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP") + + // missing gateway element + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present") + + // unparsable weight element + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") // missing route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.Nil(t, err) + require.NoError(t, err) // above network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.Nil(t, err) + require.NoError(t, err) // no mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) // bad mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") + require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") // happy case - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{ + map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, + map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, + map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, + map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} - routes, err = parseUnsafeRoutes(c, n) - assert.Nil(t, err) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) assert.Len(t, routes, 4) tested := 0 @@ -252,29 +285,120 @@ func Test_parseUnsafeRoutes(t *testing.T) { func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + require.NoError(t, err) - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, - map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{ + map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"}, + map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} - routes, err := parseUnsafeRoutes(c, n) - assert.NoError(t, err) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) - assert.NoError(t, err) + require.NoError(t, err) - ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) - ok, r := routeTree.MostSpecificContains(ip) + ip, err := netip.ParseAddr("1.0.0.2") + require.NoError(t, err) + r, ok := routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err := netip.ParseAddr("192.168.0.1") + require.NoError(t, err) + assert.Equal(t, nip, r[0].Addr()) + + ip, err = netip.ParseAddr("1.0.0.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err = netip.ParseAddr("192.168.0.2") + require.NoError(t, err) + assert.Equal(t, nip, r[0].Addr()) + + ip, err = netip.ParseAddr("1.1.0.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.False(t, ok) } + +func Test_makeMultipathUnsafeRouteTree(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + n, err := netip.ParsePrefix("10.0.0.0/24") + require.NoError(t, err) + + c.Settings["tun"] = map[string]any{ + "unsafe_routes": []any{ + map[string]any{ + "route": "192.168.86.0/24", + "via": "192.168.100.10", + }, + map[string]any{ + "route": "192.168.87.0/24", + "via": []any{ + map[string]any{ + "gateway": "10.0.0.1", + }, + map[string]any{ + "gateway": "10.0.0.2", + }, + map[string]any{ + "gateway": "10.0.0.3", + }, + }, + }, + map[string]any{ + "route": "192.168.89.0/24", + "via": []any{ + map[string]any{ + "gateway": "10.0.0.1", + "weight": 10, + }, + map[string]any{ + "gateway": "10.0.0.2", + "weight": 5, + }, + }, + }, + }, + } + + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) + assert.Len(t, routes, 3) + routeTree, err := makeRouteTree(l, routes, true) + require.NoError(t, err) + + ip, err := netip.ParseAddr("192.168.86.1") + require.NoError(t, err) + r, ok := routeTree.Lookup(ip) + assert.True(t, ok) + + nip, err := netip.ParseAddr("192.168.100.10") + require.NoError(t, err) + assert.Equal(t, nip, r[0].Addr()) + + ip, err = netip.ParseAddr("192.168.87.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) + assert.True(t, ok) + + expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1), + routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1), + routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)} + + routing.CalculateBucketsForGateways(expectedGateways) + assert.ElementsMatch(t, expectedGateways, r) + + ip, err = netip.ParseAddr("192.168.89.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) + assert.True(t, ok) + + expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10), + routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)} + + routing.CalculateBucketsForGateways(expectedGateways) + assert.ElementsMatch(t, expectedGateways, r) +} diff --git a/overlay/tun.go b/overlay/tun.go index cedd7fe..4a6377d 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,7 +1,7 @@ package overlay import ( - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -11,36 +11,36 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): - tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) + tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, tunCidr, routines > 1) + return newTun(c, l, vpnNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, tunCidr) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks) } } -func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) { +func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil } - routes, err := parseRoutes(c, cidr) + routes, err := parseRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) } - unsafeRoutes, err := parseUnsafeRoutes(c, cidr) + unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index c15827f..df1ed8d 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -6,27 +6,27 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser - fd int - cidr *net.IPNet - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -34,7 +34,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) t := &tun{ ReadWriteCloser: file, fd: deviceFd, - cidr: cidr, + vpnNetworks: vpnNetworks, l: l, } @@ -53,12 +53,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -67,7 +67,7 @@ func (t tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -87,8 +87,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() *net.IPNet { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 1c63828..7f6ba4f 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -8,15 +8,16 @@ import ( "fmt" "io" "net" + "net/netip" "os" "sync/atomic" "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" @@ -24,56 +25,62 @@ import ( type tun struct { io.ReadWriteCloser - Device string - cidr *net.IPNet - DefaultMTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] - linkAddr *netroute.LinkAddr - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + DefaultMTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - type ifReq struct { - Name [16]byte + Name [unix.IFNAMSIZ]byte Flags uint16 pad [8]byte } -var sockaddrCtlSize uintptr = 32 - const ( - _SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */ - _AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */ - _PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM - _CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info) - utunControlName = "com.apple.net.utun_control" + _SIOCAIFADDR_IN6 = 2155899162 + _UTUN_OPT_IFNAME = 2 + _IN6_IFF_NODAD = 0x0020 + _IN6_IFF_SECURED = 0x0400 + utunControlName = "com.apple.net.utun_control" ) -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 pad [8]byte } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +type addrLifetime struct { + Expire float64 + Preferred float64 + Vltime uint32 + Pltime uint32 +} + +type ifreqAlias4 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet4 + DstAddr unix.RawSockaddrInet4 + MaskAddr unix.RawSockaddrInet4 +} + +type ifreqAlias6 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet6 + DstAddr unix.RawSockaddrInet6 + PrefixMask unix.RawSockaddrInet6 + Flags uint32 + Lifetime addrLifetime +} + +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -86,66 +93,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error } } - fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL) + fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL) if err != nil { return nil, fmt.Errorf("system socket: %v", err) } - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} + var ctlInfo = &unix.CtlInfo{} + copy(ctlInfo.Name[:], utunControlName) - copy(ctlInfo.ctlName[:], utunControlName) - - err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo))) + err = unix.IoctlCtlInfo(fd, ctlInfo) if err != nil { return nil, fmt.Errorf("CTLIOCGINFO: %v", err) } - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: _AF_SYS_CONTROL, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, + err = unix.Connect(fd, &unix.SockaddrCtl{ + ID: ctlInfo.Id, + Unit: uint32(ifIndex) + 1, + }) + if err != nil { + return nil, fmt.Errorf("SYS_CONNECT: %v", err) } - _, _, errno := unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(unsafe.Pointer(&sc)), - sockaddrCtlSize, - ) - if errno != 0 { - return nil, fmt.Errorf("SYS_CONNECT: %v", errno) + name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME) + if err != nil { + return nil, fmt.Errorf("failed to retrieve tun name: %w", err) } - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(len(ifName.name)) - _, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), - 2, // SYSPROTO_CONTROL - 2, // UTUN_OPT_IFNAME - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) - if errno != 0 { - return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno) - } - name = string(ifName.name[:ifNameSize-1]) - - err = syscall.SetNonblock(fd, true) + err = unix.SetNonblock(fd, true) if err != nil { return nil, fmt.Errorf("SetNonblock: %v", err) } - file := os.NewFile(uintptr(fd), "") - t := &tun{ - ReadWriteCloser: file, + ReadWriteCloser: os.NewFile(uintptr(fd), ""), Device: name, - cidr: cidr, + vpnNetworks: vpnNetworks, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -172,7 +154,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -186,11 +168,6 @@ func (t *tun) Close() error { func (t *tun) Activate() error { devName := t.deviceBytes() - var addr, mask [4]byte - - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) - s, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM, @@ -203,66 +180,18 @@ func (t *tun) Activate() error { fd := uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - - // Set the device name - ifrf := ifReq{Name: devName} - if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to set tun device name: %s", err) - } - // Set the MTU on the device ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)} if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { return fmt.Errorf("failed to set tun mtu: %v", err) } - /* - // Set the transmit queue length - ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} - if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { - // If we can't set the queue length nebula will still work but it may lead to packet loss - l.WithError(err).Error("Failed to set tun tx queue length") - } - */ - - // Bring up the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to bring the tun device up: %s", err) + // Get the device flags + ifrf := ifReq{Name: devName} + if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to get tun flags: %s", err) } - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} linkAddr, err := getLinkAddr(t.Device) if err != nil { return err @@ -272,14 +201,18 @@ func (t *tun) Activate() error { } t.linkAddr = linkAddr - copy(routeAddr.IP[:], addr[:]) - copy(maskAddr.IP[:], mask[:]) - err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) - if err != nil { - if errors.Is(err, unix.EEXIST) { - err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr) + for _, network := range t.vpnNetworks { + if network.Addr().Is4() { + err = t.activate4(network) + if err != nil { + return err + } + } else { + err = t.activate6(network) + if err != nil { + return err + } } - return err } // Run the interface @@ -292,8 +225,89 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) activate4(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias4{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + DstAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + MaskAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: prefixToMask(network).As4(), + }, + } + + if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun v4 address: %s", err) + } + + err = addRoute(network, t.linkAddr) + if err != nil { + return err + } + + return nil +} + +func (t *tun) activate6(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias6{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: network.Addr().As16(), + }, + PrefixMask: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: prefixToMask(network).As16(), + }, + Lifetime: addrLifetime{ + // never expires + Vltime: 0xffffffff, + Pltime: 0xffffffff, + }, + //TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address? + Flags: _IN6_IFF_NODAD, + } + + if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun address: %s", err) + } + + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -329,17 +343,16 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - ok, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { + r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } - - return 0 + return routing.Gateways{} } // Get the LinkAddr for the interface of the given name -// TODO: Is there an easier way to fetch this when we create the interface? +// Is there an easier way to fetch this when we create the interface? // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers. func getLinkAddr(name string) (*netroute.LinkAddr, error) { rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) @@ -367,38 +380,21 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) { } func (t *tun) addRoutes(logErrors bool) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} routes := *t.Routes.Load() + for _, r := range routes { - if r.Via == nil || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) - - err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { t.l.WithField("route", r.Cidr). Warnf("unable to add unsafe_route, identical route already exists") } else { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { @@ -414,31 +410,12 @@ func (t *tun) addRoutes(logErrors bool) error { } func (t *tun) removeRoutes(routes []Route) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} - for _, r := range routes { if !r.Install { continue } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) - - err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := delRoute(r.Cidr, t.linkAddr) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -448,23 +425,39 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } + _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) @@ -473,19 +466,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) return nil } -func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_DELETE, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } @@ -498,7 +506,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) } func (t *tun) Read(to []byte) (int, error) { - buf := make([]byte, len(to)+4) n, err := t.ReadWriteCloser.Read(buf) @@ -536,8 +543,8 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() *net.IPNet { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -547,3 +554,13 @@ func (t *tun) Name() string { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } + +func prefixToMask(prefix netip.Prefix) netip.Addr { + pLen := 128 + if prefix.Addr().Is4() { + pLen = 32 + } + + addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) + return addr +} diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index e1e4ede..131879d 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -3,17 +3,18 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "strings" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) type disabledTun struct { - read chan []byte - cidr *net.IPNet + read chan []byte + vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter @@ -21,11 +22,11 @@ type disabledTun struct { l *logrus.Logger } -func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ - cidr: cidr, - read: make(chan []byte, queueLen), - l: l, + vpnNetworks: vpnNetworks, + read: make(chan []byte, queueLen), + l: l, } if metricsEnabled { @@ -43,12 +44,12 @@ func (*disabledTun) Activate() error { return nil } -func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways { + return routing.Gateways{} } -func (t *disabledTun) Cidr() *net.IPNet { - return t.cidr +func (t *disabledTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (*disabledTun) Name() string { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 3b1b80f..2a89cbc 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,7 +9,7 @@ import ( "fmt" "io" "io/fs" - "net" + "net/netip" "os" "os/exec" "strconv" @@ -17,10 +17,10 @@ import ( "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -47,12 +47,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr *net.IPNet - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger io.ReadWriteCloser } @@ -79,11 +79,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -151,7 +151,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -171,16 +171,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -196,8 +196,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -233,13 +243,13 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -253,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } @@ -261,7 +271,7 @@ func (t *tun) addRoutes(logErrors bool) error { cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index ba15d66..e51e112 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -7,35 +7,35 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "os" "sync" "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser - cidr *net.IPNet - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] - l *logrus.Logger + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - cidr: cidr, + vpnNetworks: vpnNetworks, ReadWriteCloser: &tunReadCloser{f: file}, l: l, } @@ -60,7 +60,7 @@ func (t *tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -80,8 +80,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -143,8 +143,8 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() *net.IPNet { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 2f06951..7d19c85 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,19 +4,20 @@ package overlay import ( - "bytes" "fmt" "io" "net" + "net/netip" "os" "strings" "sync/atomic" + "time" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" @@ -26,7 +27,7 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr *net.IPNet + vpnNetworks []netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int @@ -34,25 +35,23 @@ type tun struct { ioctlFd uintptr Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeChan chan struct{} useSystemRoutes bool l *logrus.Logger } +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks +} + type ifReq struct { Name [16]byte Flags uint16 pad [8]byte } -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 @@ -65,10 +64,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -113,7 +112,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -123,11 +122,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), - cidr: cidr, + vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), l: l, @@ -149,7 +148,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet } func (t *tun) reload(c *config.C, initial bool) error { - routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -191,11 +190,13 @@ func (t *tun) reload(c *config.C, initial bool) error { } if oldDefaultMTU != newDefaultMTU { - err := t.setDefaultRoute() - if err != nil { - t.l.Warn(err) - } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + for i := range t.vpnNetworks { + err := t.setDefaultRoute(t.vpnNetworks[i]) + if err != nil { + t.l.Warn(err) + } else { + t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + } } } @@ -231,17 +232,17 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { + r, _ := t.routeTree.Load().Lookup(ip) return r } func (t *tun) Write(b []byte) (int, error) { var nn int - max := len(b) + maximum := len(b) for { - n, err := unix.Write(t.fd, b[nn:max]) + n, err := unix.Write(t.fd, b[nn:maximum]) if n > 0 { nn += n } @@ -266,6 +267,58 @@ func (t *tun) deviceBytes() (o [16]byte) { return } +func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool { + for i := range al { + if al[i].Equal(x) { + return true + } + } + return false +} + +// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there +func (t *tun) addIPs(link netlink.Link) error { + newAddrs := make([]*netlink.Addr, len(t.vpnNetworks)) + for i := range t.vpnNetworks { + newAddrs[i] = &netlink.Addr{ + IPNet: &net.IPNet{ + IP: t.vpnNetworks[i].Addr().AsSlice(), + Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()), + }, + Label: t.vpnNetworks[i].Addr().Zone(), + } + } + + //add all new addresses + for i := range newAddrs { + //TODO: CERT-V2 do we want to stack errors and try as many ops as possible? + //AddrReplace still adds new IPs, but if their properties change it will change them as well + if err := netlink.AddrReplace(link, newAddrs[i]); err != nil { + return err + } + } + + //iterate over remainder, remove whoever shouldn't be there + al, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to get tun address list: %s", err) + } + + for i := range al { + if hasNetlinkAddr(newAddrs, al[i]) { + continue + } + err = netlink.AddrDel(link, &al[i]) + if err != nil { + t.l.WithError(err).Error("failed to remove address from tun address list") + } else { + t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + } + } + + return nil +} + func (t *tun) Activate() error { devName := t.deviceBytes() @@ -273,13 +326,8 @@ func (t *tun) Activate() error { t.watchRoutes() } - var addr, mask [4]byte - - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) - s, err := unix.Socket( - unix.AF_INET, + unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine unix.SOCK_DGRAM, unix.IPPROTO_IP, ) @@ -288,31 +336,19 @@ func (t *tun) Activate() error { } t.ioctlFd = uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - // Set the device name ifrf := ifReq{Name: devName} if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to set tun device name: %s", err) } + link, err := netlink.LinkByName(t.Device) + if err != nil { + return fmt.Errorf("failed to get tun device link: %s", err) + } + + t.deviceIndex = link.Attrs().Index + // Setup our default MTU t.setMTU() @@ -323,20 +359,21 @@ func (t *tun) Activate() error { t.l.WithError(err).Error("Failed to set tun tx queue length") } + if err = t.addIPs(link); err != nil { + return err + } + // Bring up the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring the tun device up: %s", err) } - link, err := netlink.LinkByName(t.Device) - if err != nil { - return fmt.Errorf("failed to get tun device link: %s", err) - } - t.deviceIndex = link.Attrs().Index - - if err = t.setDefaultRoute(); err != nil { - return err + //set route MTU + for i := range t.vpnNetworks { + if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { + return fmt.Errorf("failed to set default route MTU: %w", err) + } } // Set the routes @@ -362,23 +399,39 @@ func (t *tun) setMTU() { } } -func (t *tun) setDefaultRoute() error { - // Default route - dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask} +func (t *tun) setDefaultRoute(cidr netip.Prefix) error { + dr := &net.IPNet{ + IP: cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, Dst: dr, MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: t.cidr.IP, + Src: net.IP(cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, } err := netlink.RouteReplace(&nr) if err != nil { - return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) + t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") + //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` + for i := 0; i < 2; i++ { + time.Sleep(100 * time.Millisecond) + err = netlink.RouteReplace(&nr) + if err == nil { + break + } else { + t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") + } + } + if err != nil { + return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) + } } return nil @@ -392,9 +445,14 @@ func (t *tun) addRoutes(logErrors bool) error { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, - Dst: r.Cidr, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -406,7 +464,7 @@ func (t *tun) addRoutes(logErrors bool) error { err := netlink.RouteReplace(&nr) if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { @@ -426,9 +484,14 @@ func (t *tun) removeRoutes(routes []Route) { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, - Dst: r.Cidr, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -447,10 +510,6 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() *net.IPNet { - return t.cidr -} - func (t *tun) Name() string { return t.Device } @@ -492,47 +551,98 @@ func (t *tun) watchRoutes() { }() } -func (t *tun) updateRoutes(r netlink.RouteUpdate) { - if r.Gw == nil { - // Not a gateway route, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route") - return - } - - if !t.cidr.Contains(r.Gw) { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our network") - return - } - - if x := r.Dst.IP.To4(); x == nil { - // Nebula only handles ipv4 on the overlay currently - t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4") - return - } - - newTree := cidr.NewTree4[iputil.VpnIp]() - if r.Type == unix.RTM_NEWROUTE { - for _, oldR := range t.routeTree.Load().List() { - newTree.AddCIDR(oldR.CIDR, oldR.Value) +func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool { + withinNetworks := false + for i := range t.vpnNetworks { + if t.vpnNetworks[i].Contains(gwAddr) { + withinNetworks = true + break } + } - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") - newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) + return withinNetworks +} + +func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { + + var gateways routing.Gateways + + link, err := netlink.LinkByName(t.Device) + if err != nil { + t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name") + return gateways + } + + // If this route is relevant to our interface and there is a gateway then add it + if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 { + gwAddr, ok := netip.AddrFromSlice(r.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") + } else { + gwAddr = gwAddr.Unmap() + + if !t.isGatewayInVpnNetworks(gwAddr) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + } else { + gateways = append(gateways, routing.NewGateway(gwAddr, 1)) + } + } + } + + for _, p := range r.MultiPath { + // If this route is relevant to our interface and there is a gateway then add it + if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 { + gwAddr, ok := netip.AddrFromSlice(p.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address") + } else { + gwAddr = gwAddr.Unmap() + + if !t.isGatewayInVpnNetworks(gwAddr) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + } else { + // p.Hops+1 = weight of the route + gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) + } + } + } + } + + routing.CalculateBucketsForGateways(gateways) + return gateways +} + +func (t *tun) updateRoutes(r netlink.RouteUpdate) { + + gateways := t.getGatewaysFromRoute(&r.Route) + + if len(gateways) == 0 { + // No gateways relevant to our network, no routing changes required. + t.l.WithField("route", r).Debug("Ignoring route update, no gateways") + return + } + + dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + return + } + + ones, _ := r.Dst.Mask.Size() + dst := netip.PrefixFrom(dstAddr, ones) + + newTree := t.routeTree.Load().Clone() + + if r.Type == unix.RTM_NEWROUTE { + t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") + newTree.Insert(dst, gateways) } else { - gw := iputil.Ip2VpnIp(r.Gw) - for _, oldR := range t.routeTree.Load().List() { - if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw { - // This is the record to delete - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") - continue - } - - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } + t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") + newTree.Delete(dst) } - t.routeTree.Store(newTree) } @@ -542,11 +652,11 @@ func (t *tun) Close() error { } if t.ReadWriteCloser != nil { - t.ReadWriteCloser.Close() + _ = t.ReadWriteCloser.Close() } if t.ioctlFd > 0 { - os.NewFile(t.ioctlFd, "ioctlFd").Close() + _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() } return nil diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index cc0216f..5ff9b0f 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -6,7 +6,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" @@ -15,10 +15,10 @@ import ( "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -28,12 +28,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr *net.IPNet - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger io.ReadWriteCloser } @@ -59,13 +59,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -105,17 +105,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -131,8 +131,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -168,13 +178,13 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -188,15 +198,15 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { @@ -214,7 +224,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 53f57b1..67a9a5f 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -6,7 +6,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" @@ -14,20 +14,20 @@ import ( "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { - Device string - cidr *net.IPNet - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger io.ReadWriteCloser @@ -43,13 +43,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -67,7 +67,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -88,7 +88,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -124,10 +124,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -139,7 +139,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -149,23 +149,33 @@ func (t *tun) Activate() error { return t.addRoutes(false) } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { + r, _ := t.routeTree.Load().Lookup(ip) return r } func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { @@ -182,8 +192,8 @@ func (t *tun) removeRoutes(routes []Route) error { if !r.Install { continue } - - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -194,8 +204,8 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() *net.IPNet { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3833983..b6712fb 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -6,30 +6,30 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) type TestTun struct { - Device string - cidr *net.IPNet - Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + Routes []Route + routeTree *bart.Table[routing.Gateways] + l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) { - _, routes, err := getAllRoutesFromConfig(c, cidr, true) +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { + _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err } @@ -39,17 +39,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e } return &TestTun{ - Device: c.GetString("tun.dev", ""), - cidr: cidr, - Routes: routes, - routeTree: routeTree, - l: l, - rxPackets: make(chan []byte, 10), - TxPackets: make(chan []byte, 10), + Device: c.GetString("tun.dev", ""), + vpnNetworks: vpnNetworks, + Routes: routes, + routeTree: routeTree, + l: l, + rxPackets: make(chan []byte, 10), + TxPackets: make(chan []byte, 10), }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -87,8 +87,8 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways { + r, _ := t.routeTree.Lookup(ip) return r } @@ -96,8 +96,8 @@ func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() *net.IPNet { - return t.cidr +func (t *TestTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *TestTun) Name() string { diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go deleted file mode 100644 index a1acd2b..0000000 --- a/overlay/tun_water_windows.go +++ /dev/null @@ -1,208 +0,0 @@ -package overlay - -import ( - "fmt" - "io" - "net" - "os/exec" - "strconv" - "sync/atomic" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/util" - "github.com/songgao/water" -) - -type waterTun struct { - Device string - cidr *net.IPNet - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] - l *logrus.Logger - f *net.Interface - *water.Interface -} - -func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) { - // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() - t := &waterTun{ - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - } - - err := t.reload(c, true) - if err != nil { - return nil, err - } - - c.RegisterReloadCallback(func(c *config.C) { - err := t.reload(c, false) - if err != nil { - util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) - } - }) - - return t, nil -} - -func (t *waterTun) Activate() error { - var err error - t.Interface, err = water.New(water.Config{ - DeviceType: water.TUN, - PlatformSpecificParams: water.PlatformSpecificParams{ - ComponentID: "tap0901", - Network: t.cidr.String(), - }, - }) - if err != nil { - return fmt.Errorf("activate failed: %v", err) - } - - t.Device = t.Interface.Name() - - // TODO use syscalls instead of exec.Command - err = exec.Command( - `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", - fmt.Sprintf("name=%s", t.Device), - "source=static", - fmt.Sprintf("addr=%s", t.cidr.IP), - fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)), - "gateway=none", - ).Run() - if err != nil { - return fmt.Errorf("failed to run 'netsh' to set address: %s", err) - } - err = exec.Command( - `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface", - t.Device, - fmt.Sprintf("mtu=%d", t.MTU), - ).Run() - if err != nil { - return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) - } - - t.f, err = net.InterfaceByName(t.Device) - if err != nil { - return fmt.Errorf("failed to find interface named %s: %v", t.Device, err) - } - - err = t.addRoutes(false) - if err != nil { - return err - } - - return nil -} - -func (t *waterTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) - if err != nil { - return err - } - - if !initial && !change { - return nil - } - - routeTree, err := makeRouteTree(t.l, routes, false) - if err != nil { - return err - } - - // Teach nebula how to handle the routes before establishing them in the system table - oldRoutes := t.Routes.Swap(&routes) - t.routeTree.Store(routeTree) - - if !initial { - // Remove first, if the system removes a wanted route hopefully it will be re-added next - t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) - - // Ensure any routes we actually want are installed - err = t.addRoutes(true) - if err != nil { - // Catch any stray logs - util.LogWithContextIfNeeded("Failed to set routes", err, t.l) - } else { - for _, r := range findRemovedRoutes(routes, *oldRoutes) { - t.l.WithField("route", r).Info("Removed route") - } - } - } - - return nil -} - -func (t *waterTun) addRoutes(logErrors bool) error { - // Path routes - routes := *t.Routes.Load() - for _, r := range routes { - if r.Via == nil || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } - - err := exec.Command( - "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), - ).Run() - - if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - } else { - return retErr - } - } else { - t.l.WithField("route", r).Info("Added route") - } - } - - return nil -} - -func (t *waterTun) removeRoutes(routes []Route) { - for _, r := range routes { - if !r.Install { - continue - } - - err := exec.Command( - "C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), - ).Run() - if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") - } else { - t.l.WithField("route", r).Info("Removed route") - } - } -} - -func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) - return r -} - -func (t *waterTun) Cidr() *net.IPNet { - return t.cidr -} - -func (t *waterTun) Name() string { - return t.Device -} - -func (t *waterTun) Close() error { - if t.Interface == nil { - return nil - } - - return t.Interface.Close() -} - -func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") -} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index f85ee9c..7aac128 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -4,41 +4,273 @@ package overlay import ( + "crypto" "fmt" - "net" + "io" + "net/netip" "os" "path/filepath" "runtime" + "sync/atomic" "syscall" + "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" + "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/wintun" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) { +const tunGUIDLabel = "Fixed Nebula Windows GUID v1" + +type winTun struct { + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger + + tun *wintun.NativeTun +} + +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) { - useWintun := true - if err := checkWinTunExists(); err != nil { - l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") - useWintun = false - } - - if useWintun { - device, err := newWinTun(c, l, cidr, multiqueue) - if err != nil { - return nil, fmt.Errorf("create Wintun interface failed, %w", err) - } - return device, nil - } - - device, err := newWaterTun(c, l, cidr, multiqueue) +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { + err := checkWinTunExists() if err != nil { - return nil, fmt.Errorf("create wintap driver failed, %w", err) + return nil, fmt.Errorf("can not load the wintun driver: %w", err) } - return device, nil + + deviceName := c.GetString("tun.dev", "") + guid, err := generateGUIDByDeviceName(deviceName) + if err != nil { + return nil, fmt.Errorf("generate GUID failed: %w", err) + } + + t := &winTun{ + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) + if err != nil { + return nil, err + } + + var tunDevice wintun.Device + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) + if err != nil { + // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. + // Trying a second time resolves the issue. + l.WithError(err).Debug("Failed to create wintun device, retrying") + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) + if err != nil { + return nil, fmt.Errorf("create TUN device failed: %w", err) + } + } + t.tun = tunDevice.(*wintun.NativeTun) + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} + +func (t *winTun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil +} + +func (t *winTun) Activate() error { + luid := winipcfg.LUID(t.tun.LUID()) + + err := luid.SetIPAddresses(t.vpnNetworks) + if err != nil { + return fmt.Errorf("failed to set address: %w", err) + } + + err = t.addRoutes(false) + if err != nil { + return err + } + + return nil +} + +func (t *winTun) addRoutes(logErrors bool) error { + luid := winipcfg.LUID(t.tun.LUID()) + routes := *t.Routes.Load() + foundDefault4 := false + + for _, r := range routes { + if len(r.Via) == 0 || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + // Add our unsafe route + // Windows does not support multipath routes natively, so we install only a single route. + // This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. + // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. + err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) + if logErrors { + retErr.Log(t.l) + continue + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") + } + + if !foundDefault4 { + if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { + foundDefault4 = true + } + } + } + + ipif, err := luid.IPInterface(windows.AF_INET) + if err != nil { + return fmt.Errorf("failed to get ip interface: %w", err) + } + + ipif.NLMTU = uint32(t.MTU) + if foundDefault4 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + + if err := ipif.Set(); err != nil { + return fmt.Errorf("failed to set ip interface: %w", err) + } + return nil +} + +func (t *winTun) removeRoutes(routes []Route) error { + luid := winipcfg.LUID(t.tun.LUID()) + + for _, r := range routes { + if !r.Install { + continue + } + + // See comment on luid.AddRoute + err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + +func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways { + r, _ := t.routeTree.Load().Lookup(ip) + return r +} + +func (t *winTun) Networks() []netip.Prefix { + return t.vpnNetworks +} + +func (t *winTun) Name() string { + return t.Device +} + +func (t *winTun) Read(b []byte) (int, error) { + return t.tun.Read(b, 0) +} + +func (t *winTun) Write(b []byte) (int, error) { + return t.tun.Write(b, 0) +} + +func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") +} + +func (t *winTun) Close() error { + // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, + // so to be certain, just remove everything before destroying. + luid := winipcfg.LUID(t.tun.LUID()) + _ = luid.FlushRoutes(windows.AF_INET) + _ = luid.FlushIPAddresses(windows.AF_INET) + + _ = luid.FlushRoutes(windows.AF_INET6) + _ = luid.FlushIPAddresses(windows.AF_INET6) + + _ = luid.FlushDNS(windows.AF_INET) + _ = luid.FlushDNS(windows.AF_INET6) + + return t.tun.Close() +} + +func generateGUIDByDeviceName(name string) (*windows.GUID, error) { + // GUID is 128 bit + hash := crypto.MD5.New() + + _, err := hash.Write([]byte(tunGUIDLabel)) + if err != nil { + return nil, err + } + + _, err = hash.Write([]byte(name)) + if err != nil { + return nil, err + } + + sum := hash.Sum(nil) + + return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } func checkWinTunExists() error { diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go deleted file mode 100644 index 197e3a7..0000000 --- a/overlay/tun_wintun_windows.go +++ /dev/null @@ -1,278 +0,0 @@ -package overlay - -import ( - "crypto" - "fmt" - "io" - "net" - "net/netip" - "sync/atomic" - "unsafe" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/util" - "github.com/slackhq/nebula/wintun" - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -const tunGUIDLabel = "Fixed Nebula Windows GUID v1" - -type winTun struct { - Device string - cidr *net.IPNet - prefix netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] - l *logrus.Logger - - tun *wintun.NativeTun -} - -func generateGUIDByDeviceName(name string) (*windows.GUID, error) { - // GUID is 128 bit - hash := crypto.MD5.New() - - _, err := hash.Write([]byte(tunGUIDLabel)) - if err != nil { - return nil, err - } - - _, err = hash.Write([]byte(name)) - if err != nil { - return nil, err - } - - sum := hash.Sum(nil) - - return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil -} - -func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) { - deviceName := c.GetString("tun.dev", "") - guid, err := generateGUIDByDeviceName(deviceName) - if err != nil { - return nil, fmt.Errorf("generate GUID failed: %w", err) - } - - prefix, err := iputil.ToNetIpPrefix(*cidr) - if err != nil { - return nil, err - } - - t := &winTun{ - Device: deviceName, - cidr: cidr, - prefix: prefix, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - } - - err = t.reload(c, true) - if err != nil { - return nil, err - } - - var tunDevice wintun.Device - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) - if err != nil { - // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. - // Trying a second time resolves the issue. - l.WithError(err).Debug("Failed to create wintun device, retrying") - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) - if err != nil { - return nil, fmt.Errorf("create TUN device failed: %w", err) - } - } - t.tun = tunDevice.(*wintun.NativeTun) - - c.RegisterReloadCallback(func(c *config.C) { - err := t.reload(c, false) - if err != nil { - util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) - } - }) - - return t, nil -} - -func (t *winTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) - if err != nil { - return err - } - - if !initial && !change { - return nil - } - - routeTree, err := makeRouteTree(t.l, routes, false) - if err != nil { - return err - } - - // Teach nebula how to handle the routes before establishing them in the system table - oldRoutes := t.Routes.Swap(&routes) - t.routeTree.Store(routeTree) - - if !initial { - // Remove first, if the system removes a wanted route hopefully it will be re-added next - err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) - if err != nil { - util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) - } - - // Ensure any routes we actually want are installed - err = t.addRoutes(true) - if err != nil { - // Catch any stray logs - util.LogWithContextIfNeeded("Failed to add routes", err, t.l) - } - } - - return nil -} - -func (t *winTun) Activate() error { - luid := winipcfg.LUID(t.tun.LUID()) - - err := luid.SetIPAddresses([]netip.Prefix{t.prefix}) - if err != nil { - return fmt.Errorf("failed to set address: %w", err) - } - - err = t.addRoutes(false) - if err != nil { - return err - } - - return nil -} - -func (t *winTun) addRoutes(logErrors bool) error { - luid := winipcfg.LUID(t.tun.LUID()) - routes := *t.Routes.Load() - foundDefault4 := false - - for _, r := range routes { - if r.Via == nil || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } - - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - continue - } else { - return retErr - } - } - - // Add our unsafe route - err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric)) - if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - continue - } else { - return retErr - } - } else { - t.l.WithField("route", r).Info("Added route") - } - - if !foundDefault4 { - if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { - foundDefault4 = true - } - } - } - - ipif, err := luid.IPInterface(windows.AF_INET) - if err != nil { - return fmt.Errorf("failed to get ip interface: %w", err) - } - - ipif.NLMTU = uint32(t.MTU) - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - - if err := ipif.Set(); err != nil { - return fmt.Errorf("failed to set ip interface: %w", err) - } - return nil -} - -func (t *winTun) removeRoutes(routes []Route) error { - luid := winipcfg.LUID(t.tun.LUID()) - - for _, r := range routes { - if !r.Install { - continue - } - - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix") - continue - } - - err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr()) - if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") - } else { - t.l.WithField("route", r).Info("Removed route") - } - } - return nil -} - -func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) - return r -} - -func (t *winTun) Cidr() *net.IPNet { - return t.cidr -} - -func (t *winTun) Name() string { - return t.Device -} - -func (t *winTun) Read(b []byte) (int, error) { - return t.tun.Read(b, 0) -} - -func (t *winTun) Write(b []byte) (int, error) { - return t.tun.Write(b, 0) -} - -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") -} - -func (t *winTun) Close() error { - // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, - // so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(t.tun.LUID()) - _ = luid.FlushRoutes(windows.AF_INET) - _ = luid.FlushIPAddresses(windows.AF_INET) - /* We don't support IPV6 yet - _ = luid.FlushRoutes(windows.AF_INET6) - _ = luid.FlushIPAddresses(windows.AF_INET6) - */ - _ = luid.FlushDNS(windows.AF_INET) - - return t.tun.Close() -} diff --git a/overlay/user.go b/overlay/user.go index 9d819ae..8a56d66 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -2,23 +2,23 @@ package overlay import ( "io" - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { - return NewUserDevice(tunCidr) +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return NewUserDevice(vpnNetworks) } -func NewUserDevice(tunCidr *net.IPNet) (Device, error) { +func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() return &UserDevice{ - tunCidr: tunCidr, + vpnNetworks: vpnNetworks, outboundReader: or, outboundWriter: ow, inboundReader: ir, @@ -27,7 +27,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) { } type UserDevice struct { - tunCidr *net.IPNet + vpnNetworks []netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -39,9 +39,13 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } + +func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways { + return routing.Gateways{routing.NewGateway(ip, 1)} +} + func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } diff --git a/pkclient/pkclient.go b/pkclient/pkclient.go new file mode 100644 index 0000000..7061de6 --- /dev/null +++ b/pkclient/pkclient.go @@ -0,0 +1,87 @@ +package pkclient + +import ( + "crypto/ecdsa" + "crypto/x509" + "fmt" + "io" + "strconv" + + "github.com/stefanberger/go-pkcs11uri" +) + +type Client interface { + io.Closer + GetPubKey() ([]byte, error) + DeriveNoise(peerPubKey []byte) ([]byte, error) + Test() error +} + +const NoiseKeySize = 32 + +func FromUrl(pkurl string) (*PKClient, error) { + uri := pkcs11uri.New() + uri.SetAllowAnyModule(true) //todo + err := uri.Parse(pkurl) + if err != nil { + return nil, err + } + + module, err := uri.GetModule() + if err != nil { + return nil, err + } + + slotid := 0 + slot, ok := uri.GetPathAttribute("slot-id", false) + if !ok { + slotid = 0 + } else { + slotid, err = strconv.Atoi(slot) + if err != nil { + return nil, err + } + } + + pin, _ := uri.GetPIN() + id, _ := uri.GetPathAttribute("id", false) + label, _ := uri.GetPathAttribute("object", false) + + return New(module, uint(slotid), pin, id, label) +} + +func ecKeyToArray(key *ecdsa.PublicKey) []byte { + x := make([]byte, 32) + y := make([]byte, 32) + key.X.FillBytes(x) + key.Y.FillBytes(y) + return append([]byte{0x04}, append(x, y...)...) +} + +func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) { + e, err := x509.ParsePKIXPublicKey(d) + if err != nil { + return nil, err + } + switch t := e.(type) { + case *ecdsa.PublicKey: + return ecKeyToArray(e.(*ecdsa.PublicKey)), nil + default: + return nil, fmt.Errorf("unknown public key type: %T", t) + } +} + +func (c *PKClient) Test() error { + pub, err := c.GetPubKey() + if err != nil { + return fmt.Errorf("failed to get public key: %w", err) + } + out, err := c.DeriveNoise(pub) //do an ECDH with ourselves as a quick test + if err != nil { + return err + } + if len(out) != NoiseKeySize { + return fmt.Errorf("got a key of %d bytes, expected %d", len(out), NoiseKeySize) + } + return nil +} diff --git a/pkclient/pkclient_cgo.go b/pkclient/pkclient_cgo.go new file mode 100644 index 0000000..a2ead55 --- /dev/null +++ b/pkclient/pkclient_cgo.go @@ -0,0 +1,229 @@ +//go:build cgo && pkcs11 + +package pkclient + +import ( + "encoding/asn1" + "errors" + "fmt" + "log" + "math/big" + + "github.com/miekg/pkcs11" + "github.com/miekg/pkcs11/p11" +) + +type PKClient struct { + module p11.Module + session p11.Session + id []byte + label []byte + privKeyObj p11.Object + pubKeyObj p11.Object +} + +type ecdsaSignature struct { + R, S *big.Int +} + +// New tries to open a session with the HSM, select the slot and login to it +func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { + module, err := p11.OpenModule(hsmPath) + if err != nil { + return nil, fmt.Errorf("failed to load module library: %s", hsmPath) + } + + slots, err := module.Slots() + if err != nil { + module.Destroy() + return nil, err + } + + // Try to open a session on the slot + slotIdx := 0 + for i, slot := range slots { + if slot.ID() == slotId { + slotIdx = i + break + } + } + + client := &PKClient{ + module: module, + id: []byte(id), + label: []byte(label), + } + + client.session, err = slots[slotIdx].OpenWriteSession() + if err != nil { + module.Destroy() + return nil, fmt.Errorf("failed to open session on slot %d", slotId) + } + + if len(pin) != 0 { + err = client.session.Login(pin) + if err != nil { + // ignore "already logged in" + if !errors.Is(err, pkcs11.Error(256)) { + _ = client.session.Close() + return nil, fmt.Errorf("unable to login. error: %w", err) + } + } + } + + // Make sure the hsm has a private key for deriving + client.privKeyObj, err = client.findDeriveKey(client.id, client.label, true) + if err != nil { + _ = client.Close() //log out, close session, destroy module + return nil, fmt.Errorf("failed to find private key for deriving: %w", err) + } + + return client, nil +} + +// Close cleans up properly and logs out +func (c *PKClient) Close() error { + var err error = nil + if c.session != nil { + _ = c.session.Logout() //if logout fails, we still want to close + err = c.session.Close() + } + + c.module.Destroy() + return err +} + +// Try to find a suitable key on the hsm for key derivation +// parameter GET_PUB_KEY sets the search pattern for a public or private key +func (c *PKClient) findDeriveKey(id []byte, label []byte, private bool) (key p11.Object, err error) { + keyClass := pkcs11.CKO_PRIVATE_KEY + if !private { + keyClass = pkcs11.CKO_PUBLIC_KEY + } + keyAttrs := []*pkcs11.Attribute{ + //todo, not all HSMs seem to report this, even if its true: pkcs11.NewAttribute(pkcs11.CKA_DERIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass), + } + + if id != nil && len(id) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + if label != nil && len(label) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + return c.session.FindObject(keyAttrs) +} + +func (c *PKClient) listDeriveKeys(id []byte, label []byte, private bool) { + keyClass := pkcs11.CKO_PRIVATE_KEY + if !private { + keyClass = pkcs11.CKO_PUBLIC_KEY + } + keyAttrs := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass), + } + + if id != nil && len(id) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + if label != nil && len(label) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + objects, err := c.session.FindObjects(keyAttrs) + if err != nil { + return + } + + for _, obj := range objects { + l, err := obj.Label() + log.Printf("%s, %v", l, err) + a, err := obj.Attribute(pkcs11.CKA_DERIVE) + log.Printf("DERIVE: %s %v, %v", l, a, err) + } +} + +// SignASN1 signs some data. Returns the ASN.1 encoded signature. +func (c *PKClient) SignASN1(data []byte) ([]byte, error) { + mech := pkcs11.NewMechanism(pkcs11.CKM_ECDSA_SHA256, nil) + sk := p11.PrivateKey(c.privKeyObj) + rawSig, err := sk.Sign(*mech, data) + if err != nil { + return nil, err + } + + // PKCS #11 Mechanisms v2.30: + // "The signature octets correspond to the concatenation of the ECDSA values r and s, + // both represented as an octet string of equal length of at most nLen with the most + // significant byte first. If r and s have different octet length, the shorter of both + // must be padded with leading zero octets such that both have the same octet length. + // Loosely spoken, the first half of the signature is r and the second half is s." + r := new(big.Int).SetBytes(rawSig[:len(rawSig)/2]) + s := new(big.Int).SetBytes(rawSig[len(rawSig)/2:]) + return asn1.Marshal(ecdsaSignature{r, s}) +} + +// DeriveNoise derives a shared secret using the input public key against the private key that was found during setup. +// Returns a fixed 32 byte array. +func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) { + // Before we call derive, we need to have an array of attributes which specify the type of + // key to be returned, in our case, it's the shared secret key, produced via deriving + // This template pulled from OpenSC pkclient-tool.c line 4038 + attrTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, false), + pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY), + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_GENERIC_SECRET), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, false), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, true), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_WRAP, true), + pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true), + } + + // Set up the parameters which include the peer's public key + ecdhParams := pkcs11.NewECDH1DeriveParams(pkcs11.CKD_NULL, nil, peerPubKey) + mech := pkcs11.NewMechanism(pkcs11.CKM_ECDH1_DERIVE, ecdhParams) + sk := p11.PrivateKey(c.privKeyObj) + + tmpKey, err := sk.Derive(*mech, attrTemplate) + if err != nil { + return nil, err + } + if tmpKey == nil || len(tmpKey) == 0 { + return nil, fmt.Errorf("got an empty secret key") + } + secret := make([]byte, NoiseKeySize) + copy(secret[:], tmpKey[:NoiseKeySize]) + return secret, nil +} + +func (c *PKClient) GetPubKey() ([]byte, error) { + d, err := c.privKeyObj.Attribute(pkcs11.CKA_PUBLIC_KEY_INFO) + if err != nil { + return nil, err + } + if d != nil && len(d) > 0 { + return formatPubkeyFromPublicKeyInfoAttr(d) + } + c.pubKeyObj, err = c.findDeriveKey(c.id, c.label, false) + if err != nil { + return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and looking up the public key also failed: %w", err) + } + d, err = c.pubKeyObj.Attribute(pkcs11.CKA_EC_POINT) + if err != nil { + return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and reading CKA_EC_POINT also failed: %w", err) + } + if d == nil || len(d) < 1 { + return nil, fmt.Errorf("pkcs11 module gave us a nil or empty CKA_EC_POINT") + } + switch len(d) { + case 65: //length of 0x04 + len(X) + len(Y) + return d, nil + case 67: //as above, DER-encoded IIRC? + return d[2:], nil + default: + return nil, fmt.Errorf("unknown public key length: %d", len(d)) + } +} diff --git a/pkclient/pkclient_stub.go b/pkclient/pkclient_stub.go new file mode 100644 index 0000000..36b0fc9 --- /dev/null +++ b/pkclient/pkclient_stub.go @@ -0,0 +1,30 @@ +//go:build !cgo || !pkcs11 + +package pkclient + +import "errors" + +type PKClient struct { +} + +var notImplemented = errors.New("not implemented") + +func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { + return nil, notImplemented +} + +func (c *PKClient) Close() error { + return nil +} + +func (c *PKClient) SignASN1(data []byte) ([]byte, error) { + return nil, notImplemented +} + +func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) { + return nil, notImplemented +} + +func (c *PKClient) GetPubKey() ([]byte, error) { + return nil, notImplemented +} diff --git a/pki.go b/pki.go index 91478ce..888da7c 100644 --- a/pki.go +++ b/pki.go @@ -1,13 +1,19 @@ package nebula import ( + "encoding/binary" + "encoding/json" "errors" "fmt" + "net" + "net/netip" "os" + "slices" "strings" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" @@ -16,16 +22,27 @@ import ( type PKI struct { cs atomic.Pointer[CertState] - caPool atomic.Pointer[cert.NebulaCAPool] + caPool atomic.Pointer[cert.CAPool] l *logrus.Logger } type CertState struct { - Certificate *cert.NebulaCertificate - RawCertificate []byte - RawCertificateNoKey []byte - PublicKey []byte - PrivateKey []byte + v1Cert cert.Certificate + v1HandshakeBytes []byte + + v2Cert cert.Certificate + v2HandshakeBytes []byte + + defaultVersion cert.Version + privateKey []byte + pkcs11Backed bool + cipher string + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr + myVpnAddrsTable *bart.Table[struct{}] + myVpnBroadcastAddrsTable *bart.Table[struct{}] } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -45,16 +62,16 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { return pki, nil } -func (p *PKI) GetCertState() *CertState { - return p.cs.Load() -} - -func (p *PKI) GetCAPool() *cert.NebulaCAPool { +func (p *PKI) GetCAPool() *cert.CAPool { return p.caPool.Load() } +func (p *PKI) getCertState() *CertState { + return p.cs.Load() +} + func (p *PKI) reload(c *config.C, initial bool) error { - err := p.reloadCert(c, initial) + err := p.reloadCerts(c, initial) if err != nil { if initial { return err @@ -73,31 +90,94 @@ func (p *PKI) reload(c *config.C, initial bool) error { return nil } -func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { - cs, err := newCertStateFromConfig(c) +func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { + newState, err := newCertStateFromConfig(c) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } if !initial { - // did IP in cert change? if so, don't set - currentCert := p.cs.Load().Certificate - oldIPs := currentCert.Details.Ips - newIPs := cs.Certificate.Details.Ips - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + currentState := p.cs.Load() + if newState.v1Cert != nil { + if currentState.v1Cert == nil { + return util.NewContextualError("v1 certificate was added, restart required", nil, err) + } + + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()}, + nil, + ) + } + + if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()}, + nil, + ) + } + + } else if currentState.v1Cert != nil { + //TODO: CERT-V2 we should be able to tear this down + return util.NewContextualError("v1 certificate was removed, restart required", nil, err) + } + + if newState.v2Cert != nil { + if currentState.v2Cert == nil { + return util.NewContextualError("v2 certificate was added, restart required", nil, err) + } + + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()}, + nil, + ) + } + + if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()}, + nil, + ) + } + + } else if currentState.v2Cert != nil { + return util.NewContextualError("v2 certificate was removed, restart required", nil, err) + } + + // Cipher cant be hot swapped so just leave it at what it was before + newState.cipher = currentState.cipher + + } else { + newState.cipher = c.GetString("cipher", "aes") + //TODO: this sucks and we should make it not a global + switch newState.cipher { + case "aes": + noiseEndianness = binary.BigEndian + case "chachapoly": + noiseEndianness = binary.LittleEndian + default: return util.NewContextualError( - "IP in new cert was different from old", - m{"new_ip": newIPs[0], "old_ip": oldIPs[0]}, + "unknown cipher", + m{"cipher": newState.cipher}, nil, ) } } - p.cs.Store(cs) + p.cs.Store(newState) + + //TODO: CERT-V2 newState needs a stringer that does json if initial { - p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") } else { - p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") } return nil } @@ -113,34 +193,68 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { return nil } -func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) +func (cs *CertState) GetDefaultCertificate() cert.Certificate { + c := cs.getCertificate(cs.defaultVersion) + if c == nil { + panic("No default certificate found") + } + return c +} + +func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { + switch v { + case cert.Version1: + return cs.v1Cert + case cert.Version2: + return cs.v2Cert } - publicKey := certificate.Details.PublicKey - cs := &CertState{ - RawCertificate: rawCertificate, - Certificate: certificate, - PrivateKey: privateKey, - PublicKey: publicKey, + return nil +} + +// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. +// Callers must check if the return []byte is nil. +func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { + switch v { + case cert.Version1: + return cs.v1HandshakeBytes + case cert.Version2: + return cs.v2HandshakeBytes + default: + return nil + } +} + +func (cs *CertState) String() string { + b, err := cs.MarshalJSON() + if err != nil { + return fmt.Sprintf("error marshaling certificate state: %v", err) + } + return string(b) +} + +func (cs *CertState) MarshalJSON() ([]byte, error) { + msg := []json.RawMessage{} + if cs.v1Cert != nil { + b, err := cs.v1Cert.MarshalJSON() + if err != nil { + return nil, err + } + msg = append(msg, b) } - cs.Certificate.Details.PublicKey = nil - rawCertNoKey, err := cs.Certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + if cs.v2Cert != nil { + b, err := cs.v2Cert.MarshalJSON() + if err != nil { + return nil, err + } + msg = append(msg, b) } - cs.RawCertificateNoKey = rawCertNoKey - // put public key back - cs.Certificate.Details.PublicKey = cs.PublicKey - return cs, nil + + return json.Marshal(msg) } func newCertStateFromConfig(c *config.C) (*CertState, error) { - var pemPrivateKey []byte var err error privPathOrPEM := c.GetString("pki.key", "") @@ -148,20 +262,9 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, errors.New("no pki.key path or PEM data provided") } - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - - } else { - pemPrivateKey, err = os.ReadFile(privPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - } - - rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) + rawKey, curve, isPkcs11, err := loadPrivateKey(privPathOrPEM) if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + return nil, err } var rawCert []byte @@ -182,27 +285,200 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } } - nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + var crt, v1, v2 cert.Certificate + for { + // Load the certificate + crt, rawCert, err = loadCertificate(rawCert) + if err != nil { + return nil, err + } + + switch crt.Version() { + case cert.Version1: + if v1 != nil { + return nil, fmt.Errorf("v1 certificate already found in pki.cert") + } + v1 = crt + case cert.Version2: + if v2 != nil { + return nil, fmt.Errorf("v2 certificate already found in pki.cert") + } + v2 = crt + default: + return nil, fmt.Errorf("unknown certificate version %v", crt.Version()) + } + + if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { + break + } } - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") + if v1 == nil && v2 == nil { + return nil, errors.New("no certificates found in pki.cert") } - if len(nebulaCert.Details.Ips) == 0 { - return nil, fmt.Errorf("no IPs encoded in certificate") + useDefaultVersion := uint32(1) + if v1 == nil { + // The only condition that requires v2 as the default is if only a v2 certificate is present + // We do this to avoid having to configure it specifically in the config file + useDefaultVersion = 2 } - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion) + var defaultVersion cert.Version + switch rawDefaultVersion { + case 1: + if v1 == nil { + return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert") + } + defaultVersion = cert.Version1 + case 2: + defaultVersion = cert.Version2 + default: + return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion) } - return newCertState(nebulaCert, rawKey) + return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey) } -func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { +func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { + cs := CertState{ + privateKey: privateKey, + pkcs11Backed: pkcs11backed, + myVpnNetworksTable: new(bart.Table[struct{}]), + myVpnAddrsTable: new(bart.Table[struct{}]), + myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), + } + + if v1 != nil && v2 != nil { + if !slices.Equal(v1.PublicKey(), v2.PublicKey()) { + return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil) + } + + if v1.Curve() != v2.Curve() { + return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) + } + + //TODO: CERT-V2 make sure v2 has v1s address + + cs.defaultVersion = dv + } + + if v1 != nil { + if pkcs11backed { + //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v1hs, err := v1.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v1Cert = v1 + cs.v1HandshakeBytes = v1hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version1 + } + } + + if v2 != nil { + if pkcs11backed { + //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v2hs, err := v2.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v2Cert = v2 + cs.v2HandshakeBytes = v2hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version2 + } + } + + var crt cert.Certificate + crt = cs.getCertificate(cert.Version2) + if crt == nil { + // v2 certificates are a superset, only look at v1 if its all we have + crt = cs.getCertificate(cert.Version1) + } + + for _, network := range crt.Networks() { + cs.myVpnNetworks = append(cs.myVpnNetworks, network) + cs.myVpnNetworksTable.Insert(network, struct{}{}) + + cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) + cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) + + if network.Addr().Is4() { + addr := network.Masked().Addr().As4() + mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) + cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) + } + } + + return &cs, nil +} + +func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { + var pemPrivateKey []byte + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { + rawKey = []byte(privPathOrPEM) + return rawKey, cert.Curve_P256, true, nil + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } + + return +} + +func loadCertificate(b []byte) (cert.Certificate, []byte, error) { + c, b, err := cert.UnmarshalCertificateFromPEM(b) + if err != nil { + return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err) + } + + if c.Expired(time.Now()) { + return nil, b, fmt.Errorf("nebula certificate for this host is expired") + } + + if len(c.Networks()) == 0 { + return nil, b, fmt.Errorf("no networks encoded in certificate") + } + + if c.IsCA() { + return nil, b, fmt.Errorf("host certificate is a CA certificate") + } + + return c, b, nil +} + +func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { var rawCA []byte var err error @@ -221,11 +497,11 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, er } } - caPool, err := cert.NewCAPoolFromBytes(rawCA) + caPool, err := cert.NewCAPoolFromPEM(rawCA) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { - if crt.Expired(time.Now()) { + if crt.Certificate.Expired(time.Now()) { expired++ l.WithField("cert", crt).Warn("expired certificate present in CA pool") } diff --git a/punchy_test.go b/punchy_test.go index bedd2b2..56dd1c2 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewPunchyFromConfig(t *testing.T) { @@ -15,39 +16,39 @@ func TestNewPunchyFromConfig(t *testing.T) { // Test defaults p := NewPunchyFromConfig(l, c) - assert.Equal(t, false, p.GetPunch()) - assert.Equal(t, false, p.GetRespond()) + assert.False(t, p.GetPunch()) + assert.False(t, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) assert.Equal(t, 5*time.Second, p.GetRespondDelay()) // punchy deprecation c.Settings["punchy"] = true p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetPunch()) + assert.True(t, p.GetPunch()) // punchy.punch - c.Settings["punchy"] = map[interface{}]interface{}{"punch": true} + c.Settings["punchy"] = map[string]any{"punch": true} p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetPunch()) + assert.True(t, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) // punchy.respond - c.Settings["punchy"] = map[interface{}]interface{}{"respond": true} + c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) // punchy.delay - c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"} + c.Settings["punchy"] = map[string]any{"delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetDelay()) // punchy.respond_delay - c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"} + c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetRespondDelay()) } @@ -56,22 +57,22 @@ func TestPunchy_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) delay, _ := time.ParseDuration("1m") - assert.NoError(t, c.LoadString(` + require.NoError(t, c.LoadString(` punchy: delay: 1m respond: false `)) p := NewPunchyFromConfig(l, c) assert.Equal(t, delay, p.GetDelay()) - assert.Equal(t, false, p.GetRespond()) + assert.False(t, p.GetRespond()) newDelay, _ := time.ParseDuration("10m") - assert.NoError(t, c.ReloadConfigString(` + require.NoError(t, c.ReloadConfigString(` punchy: delay: 10m respond: true `)) p.reload(c, false) assert.Equal(t, newDelay, p.GetDelay()) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) } diff --git a/relay_manager.go b/relay_manager.go index 7aa06cc..7565350 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -2,14 +2,16 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" ) type relayManager struct { @@ -50,7 +52,7 @@ func (rm *relayManager) setAmRelay(v bool) { // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. -func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() for i := 0; i < 32; i++ { @@ -71,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iput Type: relayType, State: state, LocalIndex: index, - PeerIp: vpnIp, + PeerAddr: vpnIp, } if remoteIdx != nil { @@ -90,36 +92,71 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iput func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp, + fields := logrus.Fields{ + "relay": relayHostInfo.vpnAddrs[0], "initiatorRelayIndex": m.InitiatorRelayIndex, - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp}).Info("relayManager failed to update relay") + } + + if m.RelayFromAddr == nil { + fields["relayFrom"] = m.OldRelayFromAddr + } else { + fields["relayFrom"] = m.RelayFromAddr + } + + if m.RelayToAddr == nil { + fields["relayTo"] = m.OldRelayToAddr + } else { + fields["relayTo"] = m.RelayToAddr + } + + rm.l.WithFields(fields).Info("relayManager failed to update relay") return nil, fmt.Errorf("unknown relay") } return relay, nil } -func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) { - - switch m.Type { - case NebulaControl_CreateRelayRequest: - rm.handleCreateRelayRequest(h, f, m) - case NebulaControl_CreateRelayResponse: - rm.handleCreateRelayResponse(h, f, m) +func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { + msg := &NebulaControl{} + err := msg.Unmarshal(d) + if err != nil { + h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + return } + var v cert.Version + if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 { + v = cert.Version1 + + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr) + msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + + binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr) + msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + } else { + v = cert.Version2 + } + + switch msg.Type { + case NebulaControl_CreateRelayRequest: + rm.handleCreateRelayRequest(v, h, f, msg) + case NebulaControl_CreateRelayResponse: + rm.handleCreateRelayResponse(v, h, f, msg) + } } -func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { +func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(m.RelayFromIp), - "relayTo": iputil.VpnIp(m.RelayToIp), + "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), + "relayTo": protoAddrToNetAddr(m.RelayToAddr), "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("handleCreateRelayResponse") - target := iputil.VpnIp(m.RelayToIp) + + target := m.RelayToAddr + targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -131,62 +168,88 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. - peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) + peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") + rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") return } - peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) + peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") + rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") return } - if peerRelay.State == PeerRequested { - peerRelay.State = Established + switch peerRelay.State { + case Requested: + // I initiated the request to this peer, but haven't heard back from the peer yet. I must wait for this peer + // to respond to complete the connection. + case PeerRequested, Disestablished, Established: + peerHostInfo.relayState.UpdateRelayForByIpState(targetAddr, Established) resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: uint32(peerHostInfo.vpnIp), - RelayToIp: uint32(target), } + + if v == cert.Version1 { + peer := peerHostInfo.vpnAddrs[0] + if !peer.Is4() { + rm.l.WithField("relayFrom", peer). + WithField("relayTo", target). + WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). + WithField("responderRelayIndex", resp.ResponderRelayIndex). + WithField("vpnAddrs", peerHostInfo.vpnAddrs). + Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") + return + } + + b := peer.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = targetAddr.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0]) + resp.RelayToAddr = target + } + msg, err := resp.Marshal() if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.WithError(err). + Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + "relayFrom": resp.RelayFromAddr, + "relayTo": resp.RelayToAddr, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": peerHostInfo.vpnIp}). + "vpnAddrs": peerHostInfo.vpnAddrs}). Info("send CreateRelayResponse") } } } -func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { - - from := iputil.VpnIp(m.RelayFromIp) - target := iputil.VpnIp(m.RelayToIp) +func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { + from := protoAddrToNetAddr(m.RelayFromAddr) + target := protoAddrToNetAddr(m.RelayToAddr) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, "relayTo": target, "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnIp": h.vpnIp}) + "vpnAddrs": h.vpnAddrs}) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnIp { - logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself") + _, found := f.myVpnAddrsTable.Lookup(from) + if found { + logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } + // Is the target of the relay me? - if target == f.myVpnIp { + _, found = f.myVpnAddrsTable.Lookup(target) + if found { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -204,6 +267,21 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } + case Disestablished: + if existingRelay.RemoteIndex != m.InitiatorRelayIndex { + // We got a brand new Relay request, because its index is different than what we saw before. + // This should never happen. The peer should never change an index, once created. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + return + } + // Mark the relay as 'Established' because it's safe to use again + h.relayState.UpdateRelayForByIpState(from, Established) + case PeerRequested: + // I should never be in this state, because I am terminal, not forwarding. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex, + "state": existingRelay.State}).Error("Unexpected Relay State found") } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) @@ -215,7 +293,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { - logMsg.Error("Relay State not found") + logMsg.WithField("from", from).Error("Relay State not found") return } @@ -223,9 +301,18 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(from), - RelayToIp: uint32(target), } + + if v == cert.Version1 { + b := from.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(from) + resp.RelayToAddr = netAddrToProtoAddr(target) + } + msg, err := resp.Marshal() if err != nil { logMsg. @@ -233,11 +320,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + "relayFrom": from, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("send CreateRelayResponse") } return @@ -246,110 +333,80 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N if !rm.GetAmRelay() { return } - peer := rm.hostmap.QueryVpnIp(target) + peer := rm.hostmap.QueryVpnAddr(target) if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! f.Handshake(target) return } - if peer.remote == nil { + if !peer.remote.IsValid() { // Only create relays to peers for whom I have a direct connection return } - sendCreateRequest := false var index uint32 var err error targetRelay, ok := peer.relayState.QueryRelayForByIp(from) if ok { index = targetRelay.LocalIndex - if targetRelay.State == Requested { - sendCreateRequest = true - } } else { // Allocate an index in the hostMap for this relay peer index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested) if err != nil { return } - sendCreateRequest = true } - if sendCreateRequest { - // Send a CreateRelayRequest to the peer. - req := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: index, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), - } - msg, err := req.Marshal() - if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to marshal Control message to create relay") - } else { - f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": target}). - Info("send CreateRelayRequest") - } + peer.relayState.UpdateRelayForByIpState(from, Requested) + // Send a CreateRelayRequest to the peer. + req := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: index, } + + if v == cert.Version1 { + if !h.vpnAddrs[0].Is4() { + rm.l.WithField("relayFrom", h.vpnAddrs[0]). + WithField("relayTo", target). + WithField("initiatorRelayIndex", req.InitiatorRelayIndex). + WithField("responderRelayIndex", req.ResponderRelayIndex). + WithField("vpnAddr", target). + Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") + return + } + + b := h.vpnAddrs[0].As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0]) + req.RelayToAddr = netAddrToProtoAddr(target) + } + + msg, err := req.Marshal() + if err != nil { + logMsg. + WithError(err).Error("relayManager Failed to marshal Control message to create relay") + } else { + f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + "relayFrom": h.vpnAddrs[0], + "relayTo": target, + "initiatorRelayIndex": req.InitiatorRelayIndex, + "responderRelayIndex": req.ResponderRelayIndex, + "vpnAddr": target}). + Info("send CreateRelayRequest") + } + // Also track the half-created Relay state just received - relay, ok := h.relayState.QueryRelayForByIp(target) + _, ok = h.relayState.QueryRelayForByIp(target) if !ok { - // Add the relay - state := PeerRequested - if targetRelay != nil && targetRelay.State == Established { - state = Established - } - _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, state) + _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { logMsg. WithError(err).Error("relayManager Failed to allocate a local index for relay") return } - } else { - switch relay.State { - case Established: - if relay.RemoteIndex != m.InitiatorRelayIndex { - // We got a brand new Relay request, because its index is different than what we saw before. - // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") - return - } - resp := NebulaControl{ - Type: NebulaControl_CreateRelayResponse, - ResponderRelayIndex: relay.LocalIndex, - InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), - } - msg, err := resp.Marshal() - if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") - } else { - f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). - Info("send CreateRelayResponse") - } - - case Requested: - // Keep waiting for the other relay to complete - } } } } - -func (rm *relayManager) RemoveRelay(localIdx uint32) { - rm.hostmap.RemoveRelay(localIdx) -} diff --git a/remote_list.go b/remote_list.go index b07d15c..8d5c6ae 100644 --- a/remote_list.go +++ b/remote_list.go @@ -1,26 +1,24 @@ package nebula import ( - "bytes" "context" "net" "net/netip" + "slices" "sort" "strconv" "sync/atomic" "time" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // forEachFunc is used to benefit folks that want to do work inside the lock -type forEachFunc func(addr *udp.Addr, preferred bool) +type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -29,14 +27,11 @@ type CacheMap map[string]*Cache // Cache is the other part of CacheMap to better represent the lighthouse cache for humans // We don't reason about ipv4 vs ipv6 here type Cache struct { - Learned []*udp.Addr `json:"learned,omitempty"` - Reported []*udp.Addr `json:"reported,omitempty"` - Relay []*net.IP `json:"relay"` + Learned []netip.AddrPort `json:"learned,omitempty"` + Reported []netip.AddrPort `json:"reported,omitempty"` + Relay []netip.Addr `json:"relay"` } -//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion -// We will never clean learned/reported information for them as it stands today - // cache is an internal struct that splits v4 and v6 addresses inside the cache map type cache struct { v4 *cacheV4 @@ -45,19 +40,19 @@ type cache struct { } type cacheRelay struct { - relay []uint32 + relay []netip.Addr } // cacheV4 stores learned and reported ipv4 records under cache type cacheV4 struct { - learned *Ip4AndPort - reported []*Ip4AndPort + learned *V4AddrPort + reported []*V4AddrPort } // cacheV4 stores learned and reported ipv6 records under cache type cacheV6 struct { - learned *Ip6AndPort - reported []*Ip6AndPort + learned *V6AddrPort + reported []*V6AddrPort } type hostnamePort struct { @@ -129,7 +124,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, continue } for _, a := range addrs { - netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} + netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{} } } origSet := r.ips.Load() @@ -172,7 +167,7 @@ func (hr *hostnamesResults) Cancel() { } } -func (hr *hostnamesResults) GetIPs() []netip.AddrPort { +func (hr *hostnamesResults) GetAddrs() []netip.AddrPort { var retSlice []netip.AddrPort if hr != nil { p := hr.ips.Load() @@ -191,37 +186,43 @@ type RemoteList struct { // Every interaction with internals requires a lock! syncRWMutex + // The full list of vpn addresses assigned to this host + vpnAddrs []netip.Addr + // A deduplicated set of addresses. Any accessor should lock beforehand. - addrs []*udp.Addr + addrs []netip.AddrPort // A set of relay addresses. VpnIp addresses that the remote identified as relays. - relays []*iputil.VpnIp + relays []netip.Addr // These are maps to store v4 and v6 addresses per lighthouse // Map key is the vpnIp of the person that told us about this the cached entries underneath. // For learned addresses, this is the vpnIp that sent the packet - cache map[iputil.VpnIp]*cache + cache map[netip.Addr]*cache hr *hostnamesResults shouldAdd func(netip.Addr) bool // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake - badRemotes []*udp.Addr + badRemotes []netip.AddrPort // A flag that the cache may have changed and addrs needs to be rebuilt shouldRebuild bool } // NewRemoteList creates a new empty RemoteList -func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { - return &RemoteList{ +func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { + r := &RemoteList{ syncRWMutex: newSyncRWMutex("remote-list"), - addrs: make([]*udp.Addr, 0), - relays: make([]*iputil.VpnIp, 0), - cache: make(map[iputil.VpnIp]*cache), + vpnAddrs: make([]netip.Addr, len(vpnAddrs)), + addrs: make([]netip.AddrPort, 0), + relays: make([]netip.Addr, 0), + cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } + copy(r.vpnAddrs, vpnAddrs) + return r } func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { @@ -232,7 +233,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { +func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { r.Rebuild(preferredRanges) r.RLock() defer r.RUnlock() @@ -241,18 +242,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { // ForEach locks and will call the forEachFunc for every deduplicated address in the list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) { +func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) { r.Rebuild(preferredRanges) r.RLock() for _, v := range r.addrs { - forEach(v, isPreferred(v.IP, preferredRanges)) + forEach(v, isPreferred(v.Addr(), preferredRanges)) } r.RUnlock() } // CopyAddrs locks and makes a deep copy of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { +func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort { if r == nil { return nil } @@ -261,9 +262,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.addrs)) + c := make([]netip.AddrPort, len(r.addrs)) for i, v := range r.addrs { - c[i] = v.Copy() + c[i] = v } return c } @@ -271,14 +272,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { // LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available -// TODO: this needs to support the allow list list -func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { +func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() - if v4 := addr.IP.To4(); v4 != nil { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port))) + if remote.Addr().Is4() { + r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port))) + r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port())) } } @@ -293,9 +293,9 @@ func (r *RemoteList) CopyCache() *CacheMap { c := cm[vpnIp] if c == nil { c = &Cache{ - Learned: make([]*udp.Addr, 0), - Reported: make([]*udp.Addr, 0), - Relay: make([]*net.IP, 0), + Learned: make([]netip.AddrPort, 0), + Reported: make([]netip.AddrPort, 0), + Relay: make([]netip.Addr, 0), } cm[vpnIp] = c } @@ -307,28 +307,27 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned)) + c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH4(a)) + c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned)) + c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH6(a)) + c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a)) } } if mc.relay != nil { for _, a := range mc.relay.relay { - nip := iputil.VpnIp(a).ToIP() - c.Relay = append(c.Relay, &nip) + c.Relay = append(c.Relay, a) } } } @@ -337,8 +336,8 @@ func (r *RemoteList) CopyCache() *CacheMap { } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list -func (r *RemoteList) BlockRemote(bad *udp.Addr) { - if bad == nil { +func (r *RemoteList) BlockRemote(bad netip.AddrPort) { + if !bad.IsValid() { // relays can have nil udp Addrs return } @@ -351,20 +350,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) { } // We copy here because we are taking something else's memory and we can't trust everything - r.badRemotes = append(r.badRemotes, bad.Copy()) + r.badRemotes = append(r.badRemotes, bad) // Mark the next interaction must recollect/dedupe r.shouldRebuild = true } // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list -func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr { +func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.badRemotes)) + c := make([]netip.AddrPort, len(r.badRemotes)) for i, v := range r.badRemotes { - c[i] = v.Copy() + c[i] = v } return c } @@ -378,12 +377,11 @@ func (r *RemoteList) ResetBlockedRemotes() { // Rebuild locks and generates the deduplicated address list only if there is work to be done // There is generally no reason to call this directly but it is safe to do so -func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { +func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { r.Lock() defer r.Unlock() // Only rebuild if the cache changed - //TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in if r.shouldRebuild { r.unlockedCollect() r.shouldRebuild = false @@ -394,9 +392,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { } // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list -func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { +func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { for _, v := range r.badRemotes { - if v.Equals(remote) { + if v == remote { return true } } @@ -405,14 +403,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -427,7 +425,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, } } -func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) { +func (r *RemoteList) unlockedSetRelay(ownerVpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) @@ -440,12 +438,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip4AndPort{to}, c.reported...) + c.reported = append([]*V4AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -453,14 +451,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -477,18 +475,18 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip6AndPort{to}, c.reported...) + c.reported = append([]*V6AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } } -func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay { +func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -503,7 +501,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { +func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -518,7 +516,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 { +func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -540,14 +538,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := NewUDPAddrFromLH4(c.v4.learned) + u := protoV4AddrPortToNetAddrPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := NewUDPAddrFromLH4(v) + u := protoV4AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -556,14 +554,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := NewUDPAddrFromLH6(c.v6.learned) + u := protoV6AddrPortToNetAddrPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := NewUDPAddrFromLH6(v) + u := protoV6AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -572,20 +570,17 @@ func (r *RemoteList) unlockedCollect() { if c.relay != nil { for _, v := range c.relay.relay { - ip := iputil.VpnIp(v) - relays = append(relays, &ip) + relays = append(relays, v) } } } - dnsAddrs := r.hr.GetIPs() + dnsAddrs := r.hr.GetAddrs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { - v6 := addr.Addr().As16() - addrs = append(addrs, &udp.Addr{ - IP: v6[:], - Port: addr.Port(), - }) + if !r.unlockedIsBad(addr) { + addrs = append(addrs, addr) + } } } @@ -595,7 +590,22 @@ func (r *RemoteList) unlockedCollect() { } // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list -func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { +func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { + // Use a map to deduplicate any relay addresses + dedupedRelays := map[netip.Addr]struct{}{} + for _, relay := range r.relays { + dedupedRelays[relay] = struct{}{} + } + r.relays = r.relays[:0] + for relay := range dedupedRelays { + r.relays = append(r.relays, relay) + } + // Put them in a somewhat consistent order after de-duplication + slices.SortFunc(r.relays, func(a, b netip.Addr) int { + return a.Compare(b) + }) + + // Now the addrs n := len(r.addrs) if n < 2 { return @@ -606,8 +616,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { b := r.addrs[j] // Preferred addresses first - aPref := isPreferred(a.IP, preferredRanges) - bPref := isPreferred(b.IP, preferredRanges) + aPref := isPreferred(a.Addr(), preferredRanges) + bPref := isPreferred(b.Addr(), preferredRanges) switch { case aPref && !bPref: // If i is preferred and j is not, i is less than j @@ -622,21 +632,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // ipv6 addresses 2nd - a4 := a.IP.To4() - b4 := b.IP.To4() + a4 := a.Addr().Is4() + b4 := b.Addr().Is4() switch { - case a4 == nil && b4 != nil: + case a4 == false && b4 == true: // If i is v6 and j is v4, i is less than j return true - case a4 != nil && b4 == nil: + case a4 == true && b4 == false: // If j is v6 and i is v4, i is not less than j return false - case a4 != nil && b4 != nil: - // Special case for ipv4, a4 and b4 are not nil - aPrivate := isPrivateIP(a4) - bPrivate := isPrivateIP(b4) + case a4 == true && b4 == true: + // i and j are both ipv4 + aPrivate := a.Addr().IsPrivate() + bPrivate := b.Addr().IsPrivate() switch { case !aPrivate && bPrivate: // If i is a public ip (not private) and j is a private ip, i is less then j @@ -655,10 +665,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // lexical order of ips 3rd - c := bytes.Compare(a.IP, b.IP) + c := a.Addr().Compare(b.Addr()) if c == 0 { // Ips are the same, Lexical order of ports 4th - return a.Port < b.Port + return a.Port() < b.Port() } // Ip wasn't the same @@ -671,7 +681,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { // Deduplicate a, b := 0, 1 for b < n { - if !r.addrs[a].Equals(r.addrs[b]) { + if r.addrs[a] != r.addrs[b] { a++ if a != b { r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a] @@ -693,8 +703,7 @@ func minInt(a, b int) int { } // isPreferred returns true of the ip is contained in the preferredRanges list -func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { - //TODO: this would be better in a CIDR6Tree +func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { for _, p := range preferredRanges { if p.Contains(ip) { return true @@ -702,14 +711,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { } return false } - -var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8") -var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12") -var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16") - -// isPrivateIP returns true if the ip is contained by a rfc 1918 private range -func isPrivateIP(ip net.IP) bool { - //TODO: another great cidrtree option - //TODO: Private for ipv6 or just let it ride? - return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip) -} diff --git a/remote_list_test.go b/remote_list_test.go index 49aa171..0caf86a 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -1,47 +1,57 @@ package nebula import ( - "net" + "encoding/binary" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" ) func TestRemoteList_Rebuild(t *testing.T) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( - 0, - 0, - []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), + []*V4AddrPort{ + newIp4AndPortFromString("70.199.182.92:1475"), // this is duped + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( - 1, - 1, - []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped - NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe - NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe + netip.MustParseAddr("0.0.0.1"), + netip.MustParseAddr("0.0.0.1"), + []*V6AddrPort{ + newIp6AndPortFromString("[1::1]:1"), // this is duped + newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe + newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) - rl.Rebuild([]*net.IPNet{}) + rl.unlockedSetRelay( + netip.MustParseAddr("0.0.0.1"), + []netip.Addr{ + netip.MustParseAddr("1::1"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1::1"), + }, + ) + + rl.Rebuild([]netip.Prefix{}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv6 first, sorted lexically within @@ -59,9 +69,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) // Now ensure we can hoist ipv4 up - _, ipNet, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv4 first, public then private, lexically within them @@ -78,10 +86,13 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "[1::1]:2", rl.addrs[8].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) + // assert relay deduplicated + assert.Len(t, rl.relays, 2) + assert.Equal(t, "1.2.3.4", rl.relays[0].String()) + assert.Equal(t, "1::1", rl.relays[1].String()) + // Ensure we can hoist a specific ipv4 range over anything else - _, ipNet, err = net.ParseCIDR("172.17.0.0/16") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // Preferred ipv4 first @@ -102,134 +113,147 @@ func TestRemoteList_Rebuild(t *testing.T) { } func BenchmarkFullRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( - 0, - 0, - []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), + []*V4AddrPort{ + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, - []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), + []*V6AddrPort{ + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - assert.NoError(b, err) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - assert.NoError(b, err) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(b, err) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } func BenchmarkSortRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( - 0, - 0, - []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), + []*V4AddrPort{ + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, - []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), + []*V6AddrPort{ + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - rl.Rebuild([]*net.IPNet{ipNet}) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") + rl.Rebuild([]netip.Prefix{ipNet1}) - assert.NoError(b, err) b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) - assert.NoError(b, err) b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) - assert.NoError(b, err) b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } + +func newIp4AndPortFromString(s string) *V4AddrPort { + a := netip.MustParseAddrPort(s) + v4Addr := a.Addr().As4() + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(a.Port()), + } +} + +func newIp6AndPortFromString(s string) *V6AddrPort { + a := netip.MustParseAddrPort(s) + v6Addr := a.Addr().As16() + return &V6AddrPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(a.Port()), + } +} diff --git a/routing/balance.go b/routing/balance.go new file mode 100644 index 0000000..6f52497 --- /dev/null +++ b/routing/balance.go @@ -0,0 +1,39 @@ +package routing + +import ( + "net/netip" + + "github.com/slackhq/nebula/firewall" +) + +// Hashes the packet source and destination port and always returns a positive integer +// Based on 'Prospecting for Hash Functions' +// - https://nullprogram.com/blog/2018/07/31/ +// - https://github.com/skeeto/hash-prospector +// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501 +func hashPacket(p *firewall.Packet) int { + x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + x ^= x >> 16 + x *= 0x21f0aaad + x ^= x >> 15 + x *= 0xd35a2d97 + x ^= x >> 15 + + return int(x) & 0x7FFFFFFF +} + +// For this function to work correctly it requires that the buckets for the gateways have been calculated +// If the contract is violated balancing will not work properly and the second return value will return false +func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) { + hash := hashPacket(fwPacket) + + for i := range gateways { + if hash <= gateways[i].BucketUpperBound() { + return gateways[i].Addr(), true + } + } + + // If you land here then the buckets for the gateways are not properly calculated + // Fallback to random routing and let the caller know + return gateways[hash%len(gateways)].Addr(), false +} diff --git a/routing/balance_test.go b/routing/balance_test.go new file mode 100644 index 0000000..bbfcb22 --- /dev/null +++ b/routing/balance_test.go @@ -0,0 +1,144 @@ +package routing + +import ( + "net/netip" + "testing" + + "github.com/slackhq/nebula/firewall" + "github.com/stretchr/testify/assert" +) + +func TestPacketsAreBalancedEqually(t *testing.T) { + + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + gw3Addr := netip.MustParseAddr("1.0.0.3") + + gateways = append(gateways, NewGateway(gw1Addr, 1)) + gateways = append(gateways, NewGateway(gw2Addr, 1)) + gateways = append(gateways, NewGateway(gw3Addr, 1)) + + CalculateBucketsForGateways(gateways) + + gw1count := 0 + gw2count := 0 + gw3count := 0 + + iterationCount := uint16(65535) + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.True(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + case gw3Addr: + gw3count += 1 + } + + } + + // Assert packets are balanced, allow variation of up to 100 packets per gateway + assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + +} + +func TestPacketsAreBalancedByPriority(t *testing.T) { + + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + + gateways = append(gateways, NewGateway(gw1Addr, 10)) + gateways = append(gateways, NewGateway(gw2Addr, 5)) + + CalculateBucketsForGateways(gateways) + + gw1count := 0 + gw2count := 0 + + iterationCount := uint16(65535) + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.True(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + } + + } + + iterationCountAsFloat := float32(iterationCount) + + assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count) + assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count) +} + +func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) { + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + + gateways = append(gateways, NewGateway(gw1Addr, 10)) + gateways = append(gateways, NewGateway(gw2Addr, 5)) + + iterationCount := uint16(65535) + gw1count := 0 + gw2count := 0 + + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.False(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + } + + } + + assert.Equal(t, int(iterationCount), (gw1count + gw2count)) + assert.NotEqual(t, 0, gw1count) + assert.NotEqual(t, 0, gw2count) + +} diff --git a/routing/gateway.go b/routing/gateway.go new file mode 100644 index 0000000..59d38a9 --- /dev/null +++ b/routing/gateway.go @@ -0,0 +1,70 @@ +package routing + +import ( + "fmt" + "net/netip" +) + +const ( + // Sentinal value + BucketNotCalculated = -1 +) + +type Gateways []Gateway + +func (g Gateways) String() string { + str := "" + for i, gw := range g { + str += gw.String() + if i < len(g)-1 { + str += ", " + } + } + return str +} + +type Gateway struct { + addr netip.Addr + weight int + bucketUpperBound int +} + +func NewGateway(addr netip.Addr, weight int) Gateway { + return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated} +} + +func (g *Gateway) BucketUpperBound() int { + return g.bucketUpperBound +} + +func (g *Gateway) Addr() netip.Addr { + return g.addr +} + +func (g *Gateway) String() string { + return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight) +} + +// Divide and round to nearest integer +func divideAndRound(v uint64, d uint64) uint64 { + var tmp uint64 = v + d/2 + return tmp / d +} + +// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel. +// After this function returns each gateway will have a +// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX) +func CalculateBucketsForGateways(gateways []Gateway) { + + var totalWeight int = 0 + for i := range gateways { + totalWeight += gateways[i].weight + } + + var loopWeight int = 0 + for i := range gateways { + loopWeight += gateways[i].weight + gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1 + } + +} diff --git a/routing/gateway_test.go b/routing/gateway_test.go new file mode 100644 index 0000000..8ae78f3 --- /dev/null +++ b/routing/gateway_test.go @@ -0,0 +1,34 @@ +package routing + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRebalance3_2Split(t *testing.T) { + gateways := []Gateway{} + + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5}) + + CalculateBucketsForGateways(gateways) + + assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2 + assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX +} + +func TestRebalanceEqualSplit(t *testing.T) { + gateways := []Gateway{} + + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + + CalculateBucketsForGateways(gateways) + + assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3 + assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2 + assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX +} diff --git a/service/service.go b/service/service.go index 6816be6..4339677 100644 --- a/service/service.go +++ b/service/service.go @@ -8,6 +8,7 @@ import ( "log" "math" "net" + "net/netip" "os" "strings" "sync" @@ -89,9 +90,9 @@ func New(config *config.C) (*Service, error) { }, }) - ipNet := device.Cidr() + ipNet := device.Networks() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ @@ -153,24 +154,48 @@ func New(config *config.C) (*Service, error) { return &s, nil } -// DialContext dials the provided address. Currently only TCP is supported. +func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber { + if addr.Is6() { + return ipv6.ProtocolNumber + } + return ipv4.ProtocolNumber +} + +// DialContext dials the provided address. func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if network != "tcp" && network != "tcp4" { - return nil, errors.New("only tcp is supported") + switch network { + case "udp", "udp4", "udp6": + addr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + num := getProtocolNumber(addr.AddrPort().Addr()) + return gonet.DialUDP(s.ipstack, nil, &fullAddr, num) + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + num := getProtocolNumber(addr.AddrPort().Addr()) + return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num) + default: + return nil, fmt.Errorf("unknown network type: %s", network) } +} - addr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err - } - - fullAddr := tcpip.FullAddress{ - NIC: nicID, - Addr: tcpip.AddrFromSlice(addr.IP), - Port: uint16(addr.Port), - } - - return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) +// Dial dials the provided address +func (s *Service) Dial(network, address string) (net.Conn, error) { + return s.DialContext(context.Background(), network, address) } // Listen listens on the provided address. Currently only TCP with wildcard diff --git a/service/service_test.go b/service/service_test.go index d1909cd..b9810cd 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -4,27 +4,23 @@ import ( "bytes" "context" "errors" - "net" + "net/netip" "testing" "time" "dario.cat/mergo" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/e2e" "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { - - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) - caB, err := caCrt.MarshalToPEM() +func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) + caB, err := caCrt.MarshalPEM() if err != nil { panic(err) } @@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, @@ -94,7 +90,7 @@ func TestService(t *testing.T) { "port": 4243, }, }) - b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ + b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{"localhost:4243"}, }, diff --git a/ssh.go b/ssh.go index f096121..9a26c29 100644 --- a/ssh.go +++ b/ssh.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "net" + "net/netip" "os" "reflect" "runtime" @@ -18,9 +19,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/sshd" - "github.com/slackhq/nebula/udp" ) type sshListHostMapFlags struct { @@ -78,9 +77,6 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { // that callers may invoke to run the configured ssh server. On // failure, it returns nil, error. func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { - //TODO conntrack list - //TODO print firewall rules or hash? - listen := c.GetString("sshd.listen", "") if listen == "" { return nil, fmt.Errorf("sshd.listen must be provided") @@ -94,7 +90,6 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return nil, fmt.Errorf("sshd.listen can not use port 22") } - //TODO: no good way to reload this right now hostKeyPathOrKey := c.GetString("sshd.host_key", "") if hostKeyPathOrKey == "" { return nil, fmt.Errorf("sshd.host_key must be provided") @@ -129,10 +124,10 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro } rawKeys := c.Get("sshd.authorized_users") - keys, ok := rawKeys.([]interface{}) + keys, ok := rawKeys.([]any) if ok { for _, rk := range keys { - kDef, ok := rk.(map[interface{}]interface{}) + kDef, ok := rk.(map[string]any) if !ok { l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") continue @@ -153,7 +148,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro continue } - case []interface{}: + case []any: for _, subK := range v { sk, ok := subK.(string) if !ok { @@ -195,7 +190,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") @@ -203,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListHostMap(f.hostMap, fs, w) }, }) @@ -211,7 +206,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-pending-hostmap", ShortDescription: "List all handshaking hosts", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") @@ -219,7 +214,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListHostMap(f.handshakeManager, fs, w) }, }) @@ -227,14 +222,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-lighthouse-addrmap", ShortDescription: "List all lighthouse map entries", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListLighthouseMap(f.lightHouse, fs, w) }, }) @@ -242,7 +237,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "reload", ShortDescription: "Reloads configuration from disk, same as sending HUP to the process", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshReload(c, w) }, }) @@ -256,7 +251,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "stop-cpu-profile", ShortDescription: "Stops a cpu profile and writes output to the previously provided file", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { pprof.StopCPUProfile() return w.WriteLine("If a CPU profile was running it is now stopped") }, @@ -283,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "log-level", ShortDescription: "Gets or sets the current log level", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshLogLevel(l, fs, a, w) }, }) @@ -291,7 +286,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "log-format", ShortDescription: "Gets or sets the current log format", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshLogFormat(l, fs, a, w) }, }) @@ -299,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "version", ShortDescription: "Prints the currently running version of nebula", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshVersion(f, fs, a, w) }, }) @@ -307,22 +302,22 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "device-info", ShortDescription: "Prints information about the network device.", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshDeviceInfoFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshDeviceInfo(f, fs, w) }, }) ssh.RegisterCommand(&sshd.Command{ Name: "print-cert", - ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip", - Flags: func() (*flag.FlagSet, interface{}) { + ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr", + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintCertFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json") @@ -330,21 +325,21 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintCert(f, fs, a, w) }, }) ssh.RegisterCommand(&sshd.Command{ Name: "print-tunnel", - ShortDescription: "Prints json details about a tunnel for the provided vpn ip", - Flags: func() (*flag.FlagSet, interface{}) { + ShortDescription: "Prints json details about a tunnel for the provided vpn addr", + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintTunnel(f, fs, a, w) }, }) @@ -352,74 +347,73 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-relays", ShortDescription: "Prints json details about all relay info", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintRelays(f, fs, a, w) }, }) ssh.RegisterCommand(&sshd.Command{ Name: "change-remote", - ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip", - Flags: func() (*flag.FlagSet, interface{}) { + ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr", + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshChangeRemoteFlags{} fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshChangeRemote(f, fs, a, w) }, }) ssh.RegisterCommand(&sshd.Command{ Name: "close-tunnel", - ShortDescription: "Closes a tunnel for the provided vpn ip", - Flags: func() (*flag.FlagSet, interface{}) { + ShortDescription: "Closes a tunnel for the provided vpn addr", + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCloseTunnelFlags{} fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshCloseTunnel(f, fs, a, w) }, }) ssh.RegisterCommand(&sshd.Command{ Name: "create-tunnel", - ShortDescription: "Creates a tunnel for the provided vpn ip and address", + ShortDescription: "Creates a tunnel for the provided vpn address", Help: "The lighthouses will be queried for real addresses but you can provide one as well.", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCreateTunnelFlags{} fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshCreateTunnel(f, fs, a, w) }, }) ssh.RegisterCommand(&sshd.Command{ Name: "query-lighthouse", - ShortDescription: "Query the lighthouses for the provided vpn ip", - Help: "This command is asynchronous. Only currently known udp ips will be printed.", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + ShortDescription: "Query the lighthouses for the provided vpn address", + Help: "This command is asynchronous. Only currently known udp addresses will be printed.", + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshQueryLighthouse(f, fs, a, w) }, }) } -func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { +func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { - //TODO: error return nil } @@ -431,7 +425,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 + return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0 }) if fs.Json || fs.Pretty { @@ -442,13 +436,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er err := js.Encode(hm) if err != nil { - //TODO return nil } } else { for _, v := range hm { - err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs)) + err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs)) if err != nil { return err } @@ -458,16 +451,15 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er return nil } -func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error { +func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { - //TODO: error return nil } type lighthouseInfo struct { - VpnIp string `json:"vpnIp"` - Addrs *CacheMap `json:"addrs"` + VpnAddr string `json:"vpnAddr"` + Addrs *CacheMap `json:"addrs"` } lightHouse.RLock() @@ -475,15 +467,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr x := 0 for k, v := range lightHouse.addrMap { addrMap[x] = lighthouseInfo{ - VpnIp: k.String(), - Addrs: v.CopyCache(), + VpnAddr: k.String(), + Addrs: v.CopyCache(), } x++ } lightHouse.RUnlock() sort.Slice(addrMap, func(i, j int) bool { - return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0 + return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0 }) if fs.Json || fs.Pretty { @@ -494,7 +486,6 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr err := js.Encode(addrMap) if err != nil { - //TODO return nil } @@ -504,7 +495,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr if err != nil { return err } - err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b))) + err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b))) if err != nil { return err } @@ -514,7 +505,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr return nil } -func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { err := w.WriteLine("No path to write profile provided") return err @@ -536,57 +527,54 @@ func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error { return err } -func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("%s", ifce.version)) } -func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + vpnAddr, err := netip.ParseAddr(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } var cm *CacheMap - rl := ifce.lightHouse.Query(vpnIp) + rl := ifce.lightHouse.Query(vpnAddr) if rl != nil { cm = rl.CopyCache() } return json.NewEncoder(w.GetWriter()).Encode(cm) } -func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCloseTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + vpnAddr, err := netip.ParseAddr(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } if !flags.LocalOnly { @@ -605,93 +593,89 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("Closed") } -func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCreateTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + vpnAddr, err := netip.ParseAddr(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } - var addr *udp.Addr + var addr netip.AddrPort if flags.Address != "" { - addr = udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err = netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } } - hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) - if addr != nil { + hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil) + if addr.IsValid() { hostInfo.SetRemote(addr) } return w.WriteLine("Created") } -func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshChangeRemoteFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } if flags.Address == "" { return w.WriteLine("No address was provided") } - addr := udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err := netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + vpnAddr, err := netip.ParseAddr(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } hostInfo.SetRemote(addr) return w.WriteLine("Changed") } -func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } @@ -712,7 +696,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { return err } -func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error { +func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { rate := runtime.SetMutexProfileFraction(-1) return w.WriteLine(fmt.Sprintf("Current value: %d", rate)) @@ -727,7 +711,7 @@ func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) er return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate)) } -func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } @@ -751,7 +735,7 @@ func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) } -func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } @@ -765,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } -func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } @@ -783,37 +767,34 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } -func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintCertFlags) if !ok { - //TODO: error return nil } - cert := ifce.pki.GetCertState().Certificate + cert := ifce.pki.getCertState().GetDefaultCertificate() if len(a) > 0 { - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + vpnAddr, err := netip.ParseAddr(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } - cert = hostInfo.GetCert() + cert = hostInfo.GetCert().Certificate } if args.Json || args.Pretty { b, err := cert.MarshalJSON() if err != nil { - //TODO: handle it return nil } @@ -822,7 +803,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit err := json.Indent(buf, b, "", " ") b = buf.Bytes() if err != nil { - //TODO: handle it return nil } } @@ -831,9 +811,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit } if args.Raw { - b, err := cert.MarshalToPEM() + b, err := cert.MarshalPEM() if err != nil { - //TODO: handle it return nil } @@ -843,10 +822,9 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(cert.String()) } -func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { - //TODO: error w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) return nil } @@ -862,15 +840,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp iputil.VpnIp + PeerAddr netip.Addr LocalIndex uint32 RemoteIndex uint32 - RelayedThrough []iputil.VpnIp + RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp iputil.VpnIp - RelayForIps []RelayFor + NebulaAddr netip.Addr + RelayForAddrs []RelayFor } type CmdOutput struct { @@ -886,16 +864,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } for k, v := range relays { - ro := RelayOutput{NebulaIp: v.vpnIp} + ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]} co.Relays = append(co.Relays, &ro) - relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) + relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0]) if relayHI == nil { - ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) + ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")}) continue } - for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { + for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() { rf := RelayFor{Error: nil} - r, ok := relayHI.relayState.GetRelayForByIp(vpnIp) + r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr) if ok { t := "" switch r.Type { @@ -919,19 +897,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.LocalIndex = r.LocalIndex rf.RemoteIndex = r.RemoteIndex - rf.PeerIp = r.PeerIp + rf.PeerAddr = r.PeerAddr rf.Type = t rf.State = s if rf.LocalIndex != k { rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) + relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr) if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) } - ro.RelayForIps = append(ro.RelayForIps, rf) + ro.RelayForAddrs = append(ro.RelayForAddrs, rf) } } err := enc.Encode(co) @@ -941,30 +919,28 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return nil } -func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + vpnAddr, err := netip.ParseAddr(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } enc := json.NewEncoder(w.GetWriter()) @@ -975,16 +951,18 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges())) } -func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { +func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error { data := struct { - Name string `json:"name"` - Cidr string `json:"cidr"` + Name string `json:"name"` + Cidr []netip.Prefix `json:"cidr"` }{ Name: ifce.inside.Name(), - Cidr: ifce.inside.Cidr().String(), + Cidr: make([]netip.Prefix, len(ifce.inside.Networks())), } + copy(data.Cidr, ifce.inside.Networks()) + flags, ok := fs.(*sshDeviceInfoFlags) if !ok { return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) diff --git a/sshd/command.go b/sshd/command.go index 900b01e..7323d12 100644 --- a/sshd/command.go +++ b/sshd/command.go @@ -12,7 +12,7 @@ import ( // CommandFlags is a function called before help or command execution to parse command line flags // It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags -type CommandFlags func() (*flag.FlagSet, interface{}) +type CommandFlags func() (*flag.FlagSet, any) // CommandCallback is the function called when your command should execute. // fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved @@ -21,7 +21,7 @@ type CommandFlags func() (*flag.FlagSet, interface{}) // w is the writer to use when sending messages back to the client. // If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user // where appropriate -type CommandCallback func(fs interface{}, a []string, w StringWriter) error +type CommandCallback func(fs any, a []string, w StringWriter) error type Command struct { Name string @@ -34,7 +34,7 @@ type Command struct { func execCommand(c *Command, args []string, w StringWriter) error { var ( fl *flag.FlagSet - fs interface{} + fs any ) if c.Flags != nil { @@ -57,7 +57,6 @@ func execCommand(c *Command, args []string, w StringWriter) error { func dumpCommands(c *radix.Tree, w StringWriter) { err := w.WriteLine("Available commands:") if err != nil { - //TODO: log return } @@ -67,10 +66,7 @@ func dumpCommands(c *radix.Tree, w StringWriter) { } sort.Strings(cmds) - err = w.Write(strings.Join(cmds, "\n") + "\n\n") - if err != nil { - //TODO: log - } + _ = w.Write(strings.Join(cmds, "\n") + "\n\n") } func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { @@ -89,7 +85,7 @@ func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { func matchCommand(c *radix.Tree, cmd string) []string { cmds := make([]string, 0) - c.WalkPrefix(cmd, func(found string, v interface{}) bool { + c.WalkPrefix(cmd, func(found string, v any) bool { cmds = append(cmds, found) return false }) @@ -99,7 +95,7 @@ func matchCommand(c *radix.Tree, cmd string) []string { func allCommands(c *radix.Tree) []*Command { cmds := make([]*Command, 0) - c.WalkPrefix("", func(found string, v interface{}) bool { + c.WalkPrefix("", func(found string, v any) bool { cmd, ok := v.(*Command) if ok { cmds = append(cmds, cmd) @@ -119,8 +115,6 @@ func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error) // We are printing a specific commands help text cmd, err := lookupCommand(commands, a[0]) if err != nil { - //TODO: handle error - //TODO: message the user return } diff --git a/sshd/server.go b/sshd/server.go index 9e8c721..a8b60ba 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -80,15 +80,13 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { s.config = &ssh.ServerConfig{ PublicKeyCallback: cc.Authenticate, - //TODO: AuthLogCallback: s.authAttempt, - //TODO: version string - ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), + ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), } s.RegisterCommand(&Command{ Name: "help", ShortDescription: "prints available commands or help for specific usage info", - Callback: func(a interface{}, args []string, w StringWriter) error { + Callback: func(a any, args []string, w StringWriter) error { return helpCallback(s.commands, args, w) }, }) diff --git a/sshd/session.go b/sshd/session.go index bba2a55..87cc216 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -9,13 +9,13 @@ import ( "github.com/armon/go-radix" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" + "golang.org/x/term" ) type session struct { l *logrus.Entry c *ssh.ServerConn - term *terminal.Terminal + term *term.Terminal commands *radix.Tree exitChan chan bool } @@ -31,7 +31,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New s.commands.Insert("logout", &Command{ Name: "logout", ShortDescription: "Ends the current session", - Callback: func(a interface{}, args []string, w StringWriter) error { + Callback: func(a any, args []string, w StringWriter) error { s.Close() return nil }, @@ -62,7 +62,6 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) { func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { for req := range in { var err error - //TODO: maybe support window sizing? switch req.Type { case "shell": if s.term == nil { @@ -89,9 +88,7 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { req.Reply(true, nil) s.dispatchCommand(payload.Value, &stringWriter{channel}) - //TODO: Fix error handling and report the proper status back status := struct{ Status uint32 }{uint32(0)} - //TODO: I think this is how we shut down a shell as well? channel.SendRequest("exit-status", false, ssh.Marshal(status)) channel.Close() return @@ -109,9 +106,8 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { } } -func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal { - //TODO: PS1 with nebula cert name - term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ") +func (s *session) createTerm(channel ssh.Channel) *term.Terminal { + term := term.NewTerminal(channel, s.c.User()+"@nebula > ") term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { // key 9 is tab if key == 9 { @@ -137,7 +133,6 @@ func (s *session) handleInput(channel ssh.Channel) { for { line, err := s.term.ReadLine() if err != nil { - //TODO: log break } @@ -148,7 +143,6 @@ func (s *session) handleInput(channel ssh.Channel) { func (s *session) dispatchCommand(line string, w StringWriter) { args, err := shlex.Split(line, true) if err != nil { - //todo: LOG IT return } @@ -159,13 +153,11 @@ func (s *session) dispatchCommand(line string, w StringWriter) { c, err := lookupCommand(s.commands, args[0]) if err != nil { - //TODO: handle the error return } if c == nil { err := w.WriteLine(fmt.Sprintf("did not understand: %s", line)) - //TODO: log error _ = err dumpCommands(s.commands, w) @@ -177,10 +169,7 @@ func (s *session) dispatchCommand(line string, w StringWriter) { return } - err = execCommand(c, args[1:], w) - if err != nil { - //TODO: log the error - } + _ = execCommand(c, args[1:], w) return } diff --git a/test/assert.go b/test/assert.go index 6c6c795..1856877 100644 --- a/test/assert.go +++ b/test/assert.go @@ -2,6 +2,7 @@ package test import ( "fmt" + "net/netip" "reflect" "testing" "time" @@ -12,7 +13,7 @@ import ( // AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory // There is currently a special case for `time.loc` (as this code traverses into unexported fields) -func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) { +func AssertDeepCopyEqual(t *testing.T, a any, b any) { v1 := reflect.ValueOf(a) v2 := reflect.ValueOf(b) @@ -24,6 +25,11 @@ func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) { } func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool { + if v1.Type() == v2.Type() && v1.Type() == reflect.TypeOf(netip.Addr{}) { + // Ignore netip.Addr types since they reuse an interned global value + return false + } + switch v1.Kind() { case reflect.Array: for i := 0; i < v1.Len(); i++ { diff --git a/test/tun.go b/test/tun.go index 86656c9..ca65805 100644 --- a/test/tun.go +++ b/test/tun.go @@ -3,23 +3,23 @@ package test import ( "errors" "io" - "net" + "net/netip" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) type NoopTun struct{} -func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { + return routing.Gateways{} } func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() *net.IPNet { - return nil +func (NoopTun) Networks() []netip.Prefix { + return []netip.Prefix{} } func (NoopTun) Name() string { diff --git a/timeout_test.go b/timeout_test.go index 3f81ff4..db36fec 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -1,6 +1,7 @@ package nebula import ( + "net/netip" "testing" "time" @@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: 1}, - {LocalIP: 2}, - {LocalIP: 3}, - {LocalIP: 4}, + {LocalAddr: netip.MustParseAddr("0.0.0.1")}, + {LocalAddr: netip.MustParseAddr("0.0.0.2")}, + {LocalAddr: netip.MustParseAddr("0.0.0.3")}, + {LocalAddr: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1) diff --git a/udp/conn.go b/udp/conn.go index a2c24a1..895b0df 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -1,30 +1,23 @@ package udp import ( + "net/netip" + "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) const MTU = 9001 type EncReader func( - addr *Addr, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, + addr netip.AddrPort, + payload []byte, ) type Conn interface { Rebind() error - LocalAddr() (*Addr, error) - ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) - WriteTo(b []byte, addr *Addr) error + LocalAddr() (netip.AddrPort, error) + ListenOut(r EncReader) + WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error } @@ -34,13 +27,13 @@ type NoopConn struct{} func (NoopConn) Rebind() error { return nil } -func (NoopConn) LocalAddr() (*Addr, error) { - return nil, nil +func (NoopConn) LocalAddr() (netip.AddrPort, error) { + return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { +func (NoopConn) ListenOut(_ EncReader) { return } -func (NoopConn) WriteTo(_ []byte, _ *Addr) error { +func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } func (NoopConn) ReloadConfig(_ *config.C) { diff --git a/udp/temp.go b/udp/temp.go deleted file mode 100644 index 2efe31d..0000000 --- a/udp/temp.go +++ /dev/null @@ -1,9 +0,0 @@ -package udp - -import ( - "github.com/slackhq/nebula/iputil" -) - -//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare - -type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) diff --git a/udp/udp_all.go b/udp/udp_all.go deleted file mode 100644 index 093bf69..0000000 --- a/udp/udp_all.go +++ /dev/null @@ -1,100 +0,0 @@ -package udp - -import ( - "encoding/json" - "fmt" - "net" - "strconv" -) - -type m map[string]interface{} - -type Addr struct { - IP net.IP - Port uint16 -} - -func NewAddr(ip net.IP, port uint16) *Addr { - addr := Addr{IP: make([]byte, net.IPv6len), Port: port} - copy(addr.IP, ip.To16()) - return &addr -} - -func NewAddrFromString(s string) *Addr { - ip, port, err := ParseIPAndPort(s) - //TODO: handle err - _ = err - return &Addr{IP: ip.To16(), Port: port} -} - -func (ua *Addr) Equals(t *Addr) bool { - if t == nil || ua == nil { - return t == nil && ua == nil - } - return ua.IP.Equal(t.IP) && ua.Port == t.Port -} - -func (ua *Addr) String() string { - if ua == nil { - return "" - } - - return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port)) -} - -func (ua *Addr) MarshalJSON() ([]byte, error) { - if ua == nil { - return nil, nil - } - - return json.Marshal(m{"ip": ua.IP, "port": ua.Port}) -} - -func (ua *Addr) Copy() *Addr { - if ua == nil { - return nil - } - - nu := Addr{ - Port: ua.Port, - IP: make(net.IP, len(ua.IP)), - } - - copy(nu.IP, ua.IP) - return &nu -} - -type AddrSlice []*Addr - -func (a AddrSlice) Equal(b AddrSlice) bool { - if len(a) != len(b) { - return false - } - - for i := range a { - if !a[i].Equals(b[i]) { - return false - } - } - - return true -} - -func ParseIPAndPort(s string) (net.IP, uint16, error) { - rIp, sPort, err := net.SplitHostPort(s) - if err != nil { - return nil, 0, err - } - - addr, err := net.ResolveIPAddr("ip", rIp) - if err != nil { - return nil, 0, err - } - - iPort, err := strconv.Atoi(sPort) - if err != nil { - return nil, 0, err - } - - return addr.IP, uint16(iPort), nil -} diff --git a/udp/udp_android.go b/udp/udp_android.go index 8d69074..bb19195 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -6,13 +6,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index 785aa6a..65ef31a 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -9,13 +9,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 08e1b6a..183ac7a 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 1dd6d1d..06a4d53 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -11,11 +11,10 @@ import ( "context" "fmt" "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) type GenericConn struct { @@ -25,7 +24,7 @@ type GenericConn struct { var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -37,28 +36,29 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } -func (u *GenericConn) WriteTo(b []byte, addr *Addr) error { - _, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) +func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { + _, err := u.UDPConn.WriteToUDPAddrPort(b, addr) return err } -func (u *GenericConn) LocalAddr() (*Addr, error) { +func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: - addr := &Addr{IP: make([]byte, len(v.IP))} - copy(addr.IP, v.IP) - addr.Port = uint16(v.Port) - return addr, nil + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return nil, fmt.Errorf("LocalAddr returned: %#v", a) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } func (u *GenericConn) ReloadConfig(c *config.C) { - // TODO + } func NewUDPStatsEmitter(udpConns []Conn) func() { @@ -70,24 +70,17 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *GenericConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} - nb := make([]byte, 12, 12) for { // Just read one packet at a time - n, rua, err := u.ReadFromUDP(buffer) + n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } - udpAddr.IP = rua.IP - udpAddr.Port = uint16(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 1151c89..f1936b4 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -7,19 +7,16 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "syscall" "unsafe" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) -//TODO: make it support reload as best you can! - type StdConn struct { sysFd int isV4 bool @@ -27,25 +24,6 @@ type StdConn struct { batch int } -var x int - -// From linux/sock_diag.h -const ( - _SK_MEMINFO_RMEM_ALLOC = iota - _SK_MEMINFO_RCVBUF - _SK_MEMINFO_WMEM_ALLOC - _SK_MEMINFO_SNDBUF - _SK_MEMINFO_FWD_ALLOC - _SK_MEMINFO_WMEM_QUEUED - _SK_MEMINFO_OPTMEM - _SK_MEMINFO_BACKLOG - _SK_MEMINFO_DROPS - - _SK_MEMINFO_VARS -) - -type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 - func maybeIPV4(ip net.IP) (net.IP, bool) { ip4 := ip.To4() if ip4 != nil { @@ -54,10 +32,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) { return ip, false } -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { - ipV4, isV4 := maybeIPV4(ip) +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { af := unix.AF_INET6 - if isV4 { + if ip.Is4() { af = unix.AF_INET } syscall.ForkLock.RLock() @@ -78,27 +55,21 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( } } - //TODO: support multiple listening IPs (for limiting ipv6) var sa unix.Sockaddr - if isV4 { + if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} - copy(sa4.Addr[:], ipV4) + sa4.Addr = ip.As4() sa = sa4 } else { sa6 := &unix.SockaddrInet6{Port: port} - copy(sa6.Addr[:], ip.To16()) + sa6.Addr = ip.As16() sa = sa6 } if err = unix.Bind(fd, sa); err != nil { return nil, fmt.Errorf("unable to bind to socket: %s", err) } - //TODO: this may be useful for forcing threads into specific cores - //unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU, x) - //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) - //l.Println(v, err) - - return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err + return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } func (u *StdConn) Rebind() error { @@ -113,6 +84,10 @@ func (u *StdConn) SetSendBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) } +func (u *StdConn) SetSoMark(mark int) error { + return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) +} + func (u *StdConn) GetRecvBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) } @@ -121,34 +96,31 @@ func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } -func (u *StdConn) LocalAddr() (*Addr, error) { - sa, err := unix.Getsockname(u.sysFd) - if err != nil { - return nil, err - } - - addr := &Addr{} - switch sa := sa.(type) { - case *unix.SockaddrInet4: - addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() - addr.Port = uint16(sa.Port) - case *unix.SockaddrInet6: - addr.IP = sa.Addr[0:] - addr.Port = uint16(sa.Port) - } - - return addr, nil +func (u *StdConn) GetSoMark() (int, error) { + return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) } -func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - udpAddr := &Addr{} - nb := make([]byte, 12, 12) +func (u *StdConn) LocalAddr() (netip.AddrPort, error) { + sa, err := unix.Getsockname(u.sysFd) + if err != nil { + return netip.AddrPort{}, err + } + + switch sa := sa.(type) { + case *unix.SockaddrInet4: + return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil + + case *unix.SockaddrInet6: + return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil + + default: + return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + } +} + +func (u *StdConn) ListenOut(r EncReader) { + var ip netip.Addr - //TODO: should we track this? - //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { @@ -162,15 +134,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - //metric.Update(int64(n)) for i := 0; i < n; i++ { + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { - udpAddr.IP = names[i][4:8] + ip, _ = netip.AddrFromSlice(names[i][4:8]) } else { - udpAddr.IP = names[i][8:24] + ip, _ = netip.AddrFromSlice(names[i][8:24]) } - udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } } } @@ -216,19 +187,18 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *StdConn) WriteTo(b []byte, addr *Addr) error { +func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { if u.isV4 { - return u.writeTo4(b, addr) + return u.writeTo4(b, ip) } - return u.writeTo6(b, addr) + return u.writeTo6(b, ip) } -func (u *StdConn) writeTo6(b []byte, addr *Addr) error { +func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 - // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addr.IP.To16()) + rsa.Addr = ip.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) for { _, _, err := unix.Syscall6( @@ -245,23 +215,19 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error { return &net.OpError{Op: "sendto", Err: err} } - //TODO: handle incomplete writes - return nil } } -func (u *StdConn) writeTo4(b []byte, addr *Addr) error { - addrV4, isAddrV4 := maybeIPV4(addr.IP) - if !isAddrV4 { +func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { + if !ip.Addr().Is4() { return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") } var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET - // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addrV4) + rsa.Addr = ip.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) for { _, _, err := unix.Syscall6( @@ -278,8 +244,6 @@ func (u *StdConn) writeTo4(b []byte, addr *Addr) error { return &net.OpError{Op: "sendto", Err: err} } - //TODO: handle incomplete writes - return nil } } @@ -314,10 +278,26 @@ func (u *StdConn) ReloadConfig(c *config.C) { u.l.WithError(err).Error("Failed to set listen.write_buffer") } } + + b = c.GetInt("listen.so_mark", 0) + s, err := u.GetSoMark() + if b > 0 || (err == nil && s != 0) { + err := u.SetSoMark(b) + if err == nil { + s, err := u.GetSoMark() + if err == nil { + u.l.WithField("mark", s).Info("listen.so_mark was set") + } else { + u.l.WithError(err).Warn("Failed to get listen.so_mark") + } + } else { + u.l.WithError(err).Error("Failed to set listen.so_mark") + } + } } -func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { - var vallen uint32 = 4 * _SK_MEMINFO_VARS +func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { + var vallen uint32 = 4 * unix.SK_MEMINFO_VARS _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) if err != 0 { return err @@ -326,18 +306,17 @@ func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { } func (u *StdConn) Close() error { - //TODO: this will not interrupt the read loop return syscall.Close(u.sysFd) } func NewUDPStatsEmitter(udpConns []Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges - var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge - var meminfo _SK_MEMINFO + var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge + var meminfo [unix.SK_MEMINFO_VARS]uint32 if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil { - udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) + udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) for i := range udpConns { - udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{ + udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{ metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil), @@ -354,7 +333,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() { for i, gauges := range udpGauges { if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { - for j := 0; j < _SK_MEMINFO_VARS; j++ { + for j := 0; j < unix.SK_MEMINFO_VARS; j++ { gauges[j].Update(int64(meminfo[j])) } } diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index 523968c..de8f1cd 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -39,7 +39,6 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) - //TODO: this is still silly, no need for an array vs := []iovec{ {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 87a0de7..48c5a97 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -42,7 +42,6 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) - //TODO: this is still silly, no need for an array vs := []iovec{ {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, } diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go index 3c14fac..3b69159 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 31c1a55..585b642 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "sync/atomic" "syscall" @@ -17,9 +18,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" - "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" ) @@ -61,16 +59,14 @@ type RIOConn struct { results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) { +func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } u := &RIOConn{l: l} - addr := [16]byte{} - copy(addr[:], ip.To16()) - err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port}) + err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port}) if err != nil { return nil, fmt.Errorf("bind: %w", err) } @@ -119,13 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *RIOConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -135,11 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - udpAddr.IP = rua.Addr[:] - p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) - p[0] = byte(rua.Port >> 8) - p[1] = byte(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) } } @@ -231,7 +218,7 @@ retry: return n, ep, nil } -func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { +func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { if !u.isOpen.Load() { return net.ErrClosed } @@ -274,10 +261,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { packet := u.tx.Push() packet.addr.Family = windows.AF_INET6 - p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) - p[0] = byte(addr.Port >> 8) - p[1] = byte(addr.Port) - copy(packet.addr.Addr[:], addr.IP.To16()) + packet.addr.Addr = ip.Addr().As16() + port := ip.Port() + packet.addr.Port = (port >> 8) | ((port & 0xff) << 8) copy(packet.data[:], buf) dataBuffer := &winrio.Buffer{ @@ -295,17 +281,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (u *RIOConn) LocalAddr() (*Addr, error) { +func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { - return nil, err + return netip.AddrPort{}, err } v6 := sa.(*windows.SockaddrInet6) - return &Addr{ - IP: v6.Addr[:], - Port: uint16(v6.Port), - }, nil + return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil + } func (u *RIOConn) Rebind() error { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 55985f4..8d5e6c1 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -4,42 +4,34 @@ package udp import ( - "fmt" "io" - "net" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" ) type Packet struct { - ToIp net.IP - ToPort uint16 - FromIp net.IP - FromPort uint16 - Data []byte + To netip.AddrPort + From netip.AddrPort + Data []byte } func (u *Packet) Copy() *Packet { n := &Packet{ - ToIp: make(net.IP, len(u.ToIp)), - ToPort: u.ToPort, - FromIp: make(net.IP, len(u.FromIp)), - FromPort: u.FromPort, - Data: make([]byte, len(u.Data)), + To: u.To, + From: u.From, + Data: make([]byte, len(u.Data)), } - copy(n.ToIp, u.ToIp) - copy(n.FromIp, u.FromIp) copy(n.Data, u.Data) return n } type TesterConn struct { - Addr *Addr + Addr netip.AddrPort RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula @@ -48,9 +40,9 @@ type TesterConn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ - Addr: &Addr{ip, uint16(port)}, + Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, @@ -71,7 +63,7 @@ func (u *TesterConn) Send(packet *Packet) { } if u.l.Level >= logrus.DebugLevel { u.l.WithField("header", h). - WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). + WithField("udpAddr", packet.From). WithField("dataLen", len(packet.Data)). Debug("UDP receiving injected packet") } @@ -98,42 +90,29 @@ func (u *TesterConn) Get(block bool) *Packet { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { +func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { if u.closed.Load() { return io.ErrClosedPipe } p := &Packet{ - Data: make([]byte, len(b), len(b)), - FromIp: make([]byte, 16), - FromPort: u.Addr.Port, - ToIp: make([]byte, 16), - ToPort: addr.Port, + Data: make([]byte, len(b), len(b)), + From: u.Addr, + To: addr, } copy(p.Data, b) - copy(p.ToIp, addr.IP.To16()) - copy(p.FromIp, u.Addr.IP.To16()) - u.TxPackets <- p return nil } -func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - ua := &Addr{IP: make([]byte, 16)} - nb := make([]byte, 12, 12) - +func (u *TesterConn) ListenOut(r EncReader) { for { p, ok := <-u.RxPackets if !ok { return } - ua.Port = p.FromPort - copy(ua.IP, p.FromIp.To16()) - r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, p.Data) } } @@ -144,7 +123,7 @@ func NewUDPStatsEmitter(_ []Conn) func() { return func() {} } -func (u *TesterConn) LocalAddr() (*Addr, error) { +func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil } diff --git a/udp/udp_windows.go b/udp/udp_windows.go index ebcace6..1b777c3 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -6,12 +6,13 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between diff --git a/util/error.go b/util/error.go index d7710f9..814c77a 100644 --- a/util/error.go +++ b/util/error.go @@ -9,11 +9,11 @@ import ( type ContextualError struct { RealError error - Fields map[string]interface{} + Fields map[string]any Context string } -func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError { +func NewContextualError(msg string, fields map[string]any, realError error) *ContextualError { return &ContextualError{Context: msg, Fields: fields, RealError: realError} } diff --git a/util/error_test.go b/util/error_test.go index 5041f82..692c184 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -type m map[string]interface{} +type m = map[string]any type TestLogWriter struct { Logs []string