From 2400e2392be3fcb87bf6ef3badc30ee6689ba848 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Wed, 2 Apr 2025 16:24:03 -0400 Subject: [PATCH] lint * reduce staticcheck warnings --- allow_list_test.go | 12 ++++++------ bits.go | 10 +++++----- cert/cert_v1.go | 2 -- cert/cert_v2_test.go | 10 ++++++---- cert/crypto.go | 3 +++ cert/crypto_test.go | 4 ++++ cert/helper_test.go | 3 +++ cert/pem_test.go | 9 +++++++++ cert/sign_test.go | 1 + cert_test/cert.go | 3 +++ cmd/nebula-cert/ca.go | 2 +- cmd/nebula-cert/keygen.go | 2 +- cmd/nebula-cert/main_test.go | 3 +-- cmd/nebula-cert/p11_stub.go | 2 +- cmd/nebula-cert/print.go | 6 +++--- cmd/nebula-cert/verify.go | 4 ++-- cmd/nebula-cert/verify_test.go | 2 +- cmd/nebula-service/service.go | 5 +---- config/config.go | 2 +- connection_manager.go | 10 +++++----- connection_manager_test.go | 6 +++--- control.go | 8 +++----- control_test.go | 6 ++---- firewall.go | 12 ++++++------ firewall_test.go | 5 +++++ handshake_ix.go | 8 +++----- handshake_manager_test.go | 14 -------------- header/header.go | 2 +- hostmap.go | 8 ++++---- interface.go | 8 ++++---- lighthouse.go | 30 ++++++++++-------------------- lighthouse_test.go | 8 ++++---- message_metrics.go | 4 ++-- outside.go | 2 +- overlay/route.go | 27 --------------------------- overlay/route_test.go | 3 ++- pkclient/pkclient.go | 23 ----------------------- pkclient/pkclient_cgo.go | 23 +++++++++++++++++++++++ pkclient/pkclient_stub.go | 10 +++++----- remote_list.go | 23 +++++++---------------- ssh.go | 20 ++++++++++---------- sshd/server.go | 6 +++--- sshd/session.go | 1 - udp/conn.go | 8 ++------ udp/udp_generic.go | 6 +----- 45 files changed, 158 insertions(+), 208 deletions(-) diff --git a/allow_list_test.go b/allow_list_test.go index 6135f36..e74d23a 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -25,14 +25,14 @@ func TestNewAllowListFromConfig(t *testing.T) { c.Settings["allowlist"] = map[string]any{ "192.168.0.0/16": "abc", } - r, err = newAllowListFromConfig(c, "allowlist", nil) + _, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") c.Settings["allowlist"] = map[string]any{ "192.168.0.0/16": true, "10.0.0.0/8": false, } - r, err = newAllowListFromConfig(c, "allowlist", nil) + _, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") c.Settings["allowlist"] = map[string]any{ @@ -42,7 +42,7 @@ func TestNewAllowListFromConfig(t *testing.T) { "fd00::/8": true, "fd00:fd00::/16": false, } - r, err = newAllowListFromConfig(c, "allowlist", nil) + _, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") c.Settings["allowlist"] = map[string]any{ @@ -75,7 +75,7 @@ func TestNewAllowListFromConfig(t *testing.T) { `docker.*`: "foo", }, } - lr, err := NewLocalAllowListFromConfig(c, "allowlist") + _, err = NewLocalAllowListFromConfig(c, "allowlist") require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") c.Settings["allowlist"] = map[string]any{ @@ -84,7 +84,7 @@ func TestNewAllowListFromConfig(t *testing.T) { `eth.*`: true, }, } - lr, err = NewLocalAllowListFromConfig(c, "allowlist") + _, err = NewLocalAllowListFromConfig(c, "allowlist") require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") c.Settings["allowlist"] = map[string]any{ @@ -92,7 +92,7 @@ func TestNewAllowListFromConfig(t *testing.T) { `docker.*`: false, }, } - lr, err = NewLocalAllowListFromConfig(c, "allowlist") + lr, err := NewLocalAllowListFromConfig(c, "allowlist") if assert.NoError(t, err) { assert.NotNil(t, lr) } diff --git a/bits.go b/bits.go index b4f96c6..f8f6121 100644 --- a/bits.go +++ b/bits.go @@ -18,7 +18,7 @@ type Bits struct { func NewBits(bits uint64) *Bits { return &Bits{ length: bits, - bits: make([]bool, bits, bits), + bits: make([]bool, bits), current: 0, lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), @@ -28,7 +28,7 @@ func NewBits(bits uint64) *Bits { func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool { // If i is the next number, return true. - if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) { + if i > b.current || (i == 0 && !b.firstSeen && b.current < b.length) { return true } @@ -51,7 +51,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // If i is the next number, return true and update current. if i == b.current+1 { // Report missed packets, we can only understand what was missed after the first window has been gone through - if i > b.length && b.bits[i%b.length] == false { + if i > b.length && !b.bits[i%b.length] { b.lostCounter.Inc(1) } b.bits[i%b.length] = true @@ -104,7 +104,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { } // Allow for the 0 packet to come in within the first window - if i == 0 && b.firstSeen == false && b.current < b.length { + if i == 0 && !b.firstSeen && b.current < b.length { b.firstSeen = true b.bits[i%b.length] = true return true @@ -122,7 +122,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { return false } - if b.bits[i%b.length] == true { + if b.bits[i%b.length] { if l.Level >= logrus.DebugLevel { l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}). Debug("Receive window") diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 71d36eb..5e02423 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -20,8 +20,6 @@ import ( "google.golang.org/protobuf/proto" ) -const publicKeyLen = 32 - type certificateV1 struct { details detailsV1 signature []byte diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go index c84f8c9..91dd844 100644 --- a/cert/cert_v2_test.go +++ b/cert/cert_v2_test.go @@ -113,14 +113,14 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"), } - b, err := nc.MarshalJSON() + _, err := nc.MarshalJSON() require.ErrorIs(t, err, ErrMissingDetails) rd, err := nc.details.Marshal() require.NoError(t, err) nc.rawDetails = rd - b, err = nc.MarshalJSON() + b, err := nc.MarshalJSON() require.NoError(t, err) assert.JSONEq( t, @@ -174,8 +174,9 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { require.ErrorIs(t, err, ErrInvalidPrivateKey) c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) - rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) - + _, _, curve, err = UnmarshalPrivateKeyFromPEM(priv) + assert.Equal(t, err, nil) + assert.Equal(t, curve, Curve_P256) err = c.VerifyPrivateKey(Curve_P256, priv[:16]) require.ErrorIs(t, err, ErrInvalidPrivateKey) @@ -261,6 +262,7 @@ func TestCertificateV2_marshalForSigningStability(t *testing.T) { assert.Equal(t, expectedRawDetails, db) expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162") + require.NoError(t, err) b, err := nc.marshalForSigning() require.NoError(t, err) assert.Equal(t, expectedForSigning, b) diff --git a/cert/crypto.go b/cert/crypto.go index 4c236ae..1b77a18 100644 --- a/cert/crypto.go +++ b/cert/crypto.go @@ -227,6 +227,9 @@ func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { } func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { + // Are we testing the compilers types here? + // No value of int32 is lewss than math.MinInt32. + // By definition these checks can never be true. if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) } diff --git a/cert/crypto_test.go b/cert/crypto_test.go index 6358ba6..77dcd79 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -72,12 +72,14 @@ qrlJ69wer3ZUHFXA require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.Equal(t, curve, Curve_CURVE25519) // Fail due to invalid banner curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) + assert.Equal(t, curve, Curve_CURVE25519) // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. @@ -85,12 +87,14 @@ qrlJ69wer3ZUHFXA require.EqualError(t, err, "input did not contain a valid PEM encoded block") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) + assert.Equal(t, curve, Curve_CURVE25519) // Fail due to invalid passphrase curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) require.EqualError(t, err, "invalid passphrase or corrupt private key") assert.Nil(t, k) assert.Equal(t, []byte{}, rest) + assert.Equal(t, curve, Curve_CURVE25519) } func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { diff --git a/cert/helper_test.go b/cert/helper_test.go index 1b72a0f..aed1bac 100644 --- a/cert/helper_test.go +++ b/cert/helper_test.go @@ -21,6 +21,9 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ switch curve { case Curve_CURVE25519: pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } case Curve_P256: privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { diff --git a/cert/pem_test.go b/cert/pem_test.go index 6e49249..1f416ac 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -97,12 +97,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA // Fail due to short key k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") @@ -110,6 +112,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "input did not contain a valid PEM encoded block") } @@ -159,12 +162,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Fail due to short key k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "bytes did not contain a proper private key banner") @@ -172,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "input did not contain a valid PEM encoded block") } @@ -275,12 +281,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Fail due to short key k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) @@ -288,6 +296,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "input did not contain a valid PEM encoded block") } diff --git a/cert/sign_test.go b/cert/sign_test.go index e6f43cd..96937fd 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -37,6 +37,7 @@ func TestCertificateV1_Sign(t *testing.T) { } pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) require.NoError(t, err) assert.NotNil(t, c) diff --git a/cert_test/cert.go b/cert_test/cert.go index ebc6f52..ad96f81 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -22,6 +22,9 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti switch curve { case cert.Curve_CURVE25519: pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } case cert.Curve_P256: privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index f83c94f..0c9e20d 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -81,7 +81,7 @@ func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil } -func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { +func ca(args []string, out io.Writer, _ io.Writer, pr PasswordReader) error { cf := newCaFlags() err := cf.set.Parse(args) if err != nil { diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index 496f84c..25c54d7 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -29,7 +29,7 @@ func newKeygenFlags() *keygenFlags { return &cf } -func keygen(args []string, out io.Writer, errOut io.Writer) error { +func keygen(args []string, _ io.Writer, _ io.Writer) error { cf := newKeygenFlags() err := cf.set.Parse(args) if err != nil { diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go index 2e92e7e..4c85c28 100644 --- a/cmd/nebula-cert/main_test.go +++ b/cmd/nebula-cert/main_test.go @@ -3,7 +3,6 @@ package main import ( "bytes" "errors" - "fmt" "io" "os" "testing" @@ -77,7 +76,7 @@ func assertHelpError(t *testing.T, err error, msg string) { case *helpError: // good default: - t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) + t.Fatalf("err was not a helpError: %q, expected %q", err, msg) } require.EqualError(t, err, msg) diff --git a/cmd/nebula-cert/p11_stub.go b/cmd/nebula-cert/p11_stub.go index 5afeaea..91ad4f8 100644 --- a/cmd/nebula-cert/p11_stub.go +++ b/cmd/nebula-cert/p11_stub.go @@ -10,7 +10,7 @@ func p11Supported() bool { return false } -func p11Flag(set *flag.FlagSet) *string { +func p11Flag(_ *flag.FlagSet) *string { var ret = "" return &ret } diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 30e0965..6cd3b94 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -1,12 +1,12 @@ package main import ( + "bytes" "encoding/json" "flag" "fmt" "io" "os" - "strings" "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" @@ -29,7 +29,7 @@ func newPrintFlags() *printFlags { return &pf } -func printCert(args []string, out io.Writer, errOut io.Writer) error { +func printCert(args []string, out io.Writer, _ io.Writer) error { pf := newPrintFlags() err := pf.set.Parse(args) if err != nil { @@ -72,7 +72,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { qrBytes = append(qrBytes, b...) } - if rawCert == nil || len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { + if len(rawCert) == 0 || len(bytes.TrimSpace(rawCert)) == 0 { break } diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index bea4d1d..d9a4bb6 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -1,12 +1,12 @@ package main import ( + "bytes" "errors" "flag" "fmt" "io" "os" - "strings" "time" "github.com/slackhq/nebula/cert" @@ -52,7 +52,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while adding ca cert to pool: %w", err) } - if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { + if len(rawCACert) == 0 || len(bytes.TrimSpace(rawCACert)) == 0 { break } } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f555e5f..920f53c 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -97,7 +97,7 @@ func Test_verify(t *testing.T) { crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) // Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature pub := crt.PublicKey() - for i, _ := range pub { + for i := range pub { pub[i] = 0 } b, _ = crt.MarshalPEM() diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index a54fb0f..d2d701d 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -51,10 +51,7 @@ func (p *program) Stop(s service.Service) error { func fileExists(filename string) bool { _, err := os.Stat(filename) - if os.IsNotExist(err) { - return false - } - return true + return !os.IsNotExist(err) } func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { diff --git a/config/config.go b/config/config.go index 5510324..f80ac58 100644 --- a/config/config.go +++ b/config/config.go @@ -63,7 +63,7 @@ func (c *C) Load(path string) error { func (c *C) LoadString(raw string) error { if raw == "" { - return errors.New("Empty configuration") + return errors.New("empty configuration") } return c.parseRaw([]byte(raw)) } diff --git a/connection_manager.go b/connection_manager.go index 5c9b3a5..41e6c69 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -154,7 +154,7 @@ func (n *connectionManager) Run(ctx context.Context) { defer clockSource.Stop() p := []byte("") - nb := make([]byte, 12, 12) + nb := make([]byte, 12) out := make([]byte, mtu) for { @@ -355,7 +355,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time decision = tryRehandshake } else { - if n.shouldSwapPrimary(hostinfo, primary) { + if n.shouldSwapPrimary(hostinfo) { decision = swapPrimary } else { // migrate the relays to the primary, if in use. @@ -384,7 +384,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time } decision := doNothing - if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { + if hostinfo.ConnectionState != nil && mainHostInfo { if !outTraffic { // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // Just maintain NAT state if configured to do so. @@ -421,7 +421,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time return decision, hostinfo, nil } -func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { +func (n *connectionManager) shouldSwapPrimary(current *HostInfo) bool { // The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. @@ -498,7 +498,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := n.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert myCrt := cs.getCertificate(curCrt.Version()) - if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) { // The current tunnel is using the latest certificate and version, no need to rehandshake. return } diff --git a/connection_manager_test.go b/connection_manager_test.go index d1c5ba3..08d748b 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -69,7 +69,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) p := []byte("") - nb := make([]byte, 12, 12) + nb := make([]byte, 12) out := make([]byte, mtu) // Add an ip we have established a connection w/ to hostmap @@ -151,7 +151,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) p := []byte("") - nb := make([]byte, 12, 12) + nb := make([]byte, 12) out := make([]byte, mtu) // Add an ip we have established a connection w/ to hostmap @@ -241,7 +241,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { require.NoError(t, err) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) - + require.NoError(t, err) cs := &CertState{ privateKey: []byte{}, v1Cert: &dummyCert{}, diff --git a/control.go b/control.go index 20dd7fe..6a46a3f 100644 --- a/control.go +++ b/control.go @@ -215,7 +215,7 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { hostInfo.ConnectionState, hostInfo, []byte{}, - make([]byte, 12, 12), + make([]byte, 12), make([]byte, mtu), ) } @@ -231,7 +231,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { return } - c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) + c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12), make([]byte, mtu)) c.f.closeTunnel(h) c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). @@ -282,9 +282,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { CurrentRemote: h.remote, } - for i, a := range h.vpnAddrs { - chi.VpnAddrs[i] = a - } + copy(chi.VpnAddrs, h.vpnAddrs) if h.ConnectionState != nil { chi.MessageCounter = h.ConnectionState.messageCounter.Load() diff --git a/control_test.go b/control_test.go index e400992..07f7d54 100644 --- a/control_test.go +++ b/control_test.go @@ -26,13 +26,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444") ipNet := net.IPNet{ - IP: remote1.Addr().AsSlice(), - Mask: net.IPMask{255, 255, 255, 0}, + IP: remote1.Addr().AsSlice(), } ipNet2 := net.IPNet{ - IP: remote2.Addr().AsSlice(), - Mask: net.IPMask{255, 255, 255, 0}, + IP: remote2.Addr().AsSlice(), } remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) diff --git a/firewall.go b/firewall.go index e730114..d529996 100644 --- a/firewall.go +++ b/firewall.go @@ -606,7 +606,7 @@ func (f *Firewall) evict(p firewall.Packet) { return } - newT := t.Expires.Sub(time.Now()) + newT := time.Until(t.Expires) // Timeout is in the future, re-add the timer if newT > 0 { @@ -832,7 +832,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool } // Shortcut path for if groups, hosts, or cidr contained an `any` - if fr.Any.match(p, c) { + if fr.Any.match(p) { return true } @@ -849,21 +849,21 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool found = true } - if found && sg.LocalCIDR.match(p, c) { + if found && sg.LocalCIDR.match(p) { return true } } if fr.Hosts != nil { if flc, ok := fr.Hosts[c.Certificate.Name()]; ok { - if flc.match(p, c) { + if flc.match(p) { return true } } } for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) { - if v.match(p, c) { + if v.match(p) { return true } } @@ -892,7 +892,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { return nil } -func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool { +func (flc *firewallLocalCIDR) match(p firewall.Packet) bool { if flc == nil { return false } diff --git a/firewall_test.go b/firewall_test.go index 4731a6f..2c86356 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -35,22 +35,27 @@ func TestNewFirewall(t *testing.T) { assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) + conntrack = fw.Conntrack assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) + conntrack = fw.Conntrack assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) + conntrack = fw.Conntrack assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) + conntrack = fw.Conntrack assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) + conntrack = fw.Conntrack assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) } diff --git a/handshake_ix.go b/handshake_ix.go index 571a19a..1dfaa2f 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -343,7 +343,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote - f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu)) } msg = existing.HandshakePacket[2] @@ -386,7 +386,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake too old") // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry @@ -461,8 +461,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) hostinfo.remotes.ResetBlockedRemotes() - - return } func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { @@ -660,7 +658,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha } if len(hh.packetStore) > 0 { - nb := make([]byte, 12, 12) + nb := make([]byte, 12) out := make([]byte, mtu) for _, cp := range hh.packetStore { cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 2e6d34b..a603f47 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -65,30 +65,16 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.NotContains(t, blah.vpnIps, ip) } -func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { - for _, i := range tw.t.wheel { - n := i.Head - for n != nil { - c++ - n = n.Next - } - } - return c -} - type mockEncWriter struct { } func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { - return } func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { - return } func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { - return } func (mw *mockEncWriter) Handshake(_ netip.Addr) {} diff --git a/header/header.go b/header/header.go index f22509b..86dbeaa 100644 --- a/header/header.go +++ b/header/header.go @@ -23,7 +23,7 @@ type m = map[string]any const ( Version uint8 = 1 - Len = 16 + Len int = 16 ) type MessageType uint8 diff --git a/hostmap.go b/hostmap.go index f9e3c4e..6a9d58d 100644 --- a/hostmap.go +++ b/hostmap.go @@ -568,7 +568,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } for _, addr := range hostinfo.vpnAddrs { - hm.unlockedInnerAddHostInfo(addr, hostinfo, f) + hm.unlockedInnerAddHostInfo(addr, hostinfo) } hm.Indexes[hostinfo.localIndexId] = hostinfo @@ -581,7 +581,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } -func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { +func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo) { existing := hm.Hosts[vpnAddr] hm.Hosts[vpnAddr] = hostinfo @@ -648,7 +648,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac // Try to send a test packet to that host, this should // cause it to detect a roaming event and switch remotes - ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12), make([]byte, mtu)) }) } @@ -794,7 +794,7 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { } addr = addr.Unmap() - if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { + if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { isAllowed := allowList.Allow(addr) if l.Level >= logrus.TraceLevel { l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") diff --git a/interface.go b/interface.go index a15e2c2..ecabd20 100644 --- a/interface.go +++ b/interface.go @@ -266,7 +266,7 @@ func (f *Interface) listenOut(i int) { plaintext := make([]byte, udp.MTU) h := &header.H{} fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + nb := make([]byte, 12) li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) @@ -279,7 +279,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + nb := make([]byte, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) @@ -322,7 +322,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { func (f *Interface) reloadFirewall(c *config.C) { //TODO: need to trigger/detect if the certificate changed too - if c.HasChanged("firewall") == false { + if !c.HasChanged("firewall") { f.l.Debug("No firewall config change detected") return } @@ -424,7 +424,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { certState := f.pki.getCertState() defaultCrt := certState.GetDefaultCertificate() - certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(time.Until(defaultCrt.NotAfter()) / time.Second)) certInitiatingVersion.Update(int64(defaultCrt.Version())) // Report the max certificate version we are capable of using diff --git a/lighthouse.go b/lighthouse.go index eb09a39..8cf3528 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -371,7 +371,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ } staticList := lh.GetStaticHostList() - for lhAddr, _ := range lhMap { + for lhAddr := range lhMap { if _, ok := staticList[lhAddr]; !ok { return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) } @@ -654,11 +654,8 @@ func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { } _, found := lh.myVpnNetworksTable.Lookup(to) - if found { - return false - } - return true + return !found } // unlockedShouldAddV4 checks if to is allowed by our allow list @@ -675,11 +672,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo } _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) - if found { - return false - } - - return true + return !found } // unlockedShouldAddV6 checks if to is allowed by our allow list @@ -696,11 +689,8 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo } _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) - if found { - return false - } - return true + return !found } func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { @@ -728,7 +718,7 @@ func (lh *LightHouse) startQueryWorker() { } go func() { - nb := make([]byte, 12, 12) + nb := make([]byte, 12) out := make([]byte, mtu) for { @@ -869,7 +859,7 @@ func (lh *LightHouse) SendUpdate() { } } - nb := make([]byte, 12, 12) + nb := make([]byte, 12) out := make([]byte, mtu) var v1Update, v2Update []byte @@ -971,7 +961,7 @@ type LightHouseHandler struct { func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { lhh := &LightHouseHandler{ lh: lh, - nb: make([]byte, 12, 12), + nb: make([]byte, 12), out: make([]byte, mtu), l: lh.l, pb: make([]byte, mtu), @@ -1162,7 +1152,7 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul if c.v4.learned != nil { n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned) } - if c.v4.reported != nil && len(c.v4.reported) > 0 { + if len(c.v4.reported) > 0 { n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...) } } @@ -1171,7 +1161,7 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul if c.v6.learned != nil { n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned) } - if c.v6.reported != nil && len(c.v6.reported) > 0 { + if len(c.v6.reported) > 0 { n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...) } } @@ -1369,7 +1359,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine // managed by a channel. - w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12), make([]byte, mtu)) }() } } diff --git a/lighthouse_test.go b/lighthouse_test.go index c49615c..3b04614 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -484,12 +484,12 @@ func Test_findNetworkUnion(t *testing.T) { assert.Equal(t, out, afe81) //falsey cases - out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) + _, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) assert.False(t, ok) - out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) + _, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) assert.False(t, ok) - out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) + _, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) assert.False(t, ok) - out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) + _, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) assert.False(t, ok) } diff --git a/message_metrics.go b/message_metrics.go index 10e8472..4e6b672 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -17,7 +17,7 @@ type MessageMetrics struct { func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { if m != nil { - if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) { + if int(t) < len(m.rx) && int(s) < len(m.rx[t]) { m.rx[t][s].Inc(i) } else if m.rxUnknown != nil { m.rxUnknown.Inc(i) @@ -26,7 +26,7 @@ func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int } func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) { if m != nil { - if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) { + if int(t) < len(m.tx) && int(s) < len(m.tx[t]) { m.tx[t][s].Inc(i) } else if m.txUnknown != nil { m.txUnknown.Inc(i) diff --git a/outside.go b/outside.go index 1e9cde1..3815bba 100644 --- a/outside.go +++ b/outside.go @@ -228,7 +228,7 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) { // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote func (f *Interface) sendCloseTunnel(h *HostInfo) { - f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) + f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12), make([]byte, mtu)) } func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { diff --git a/overlay/route.go b/overlay/route.go index 6198958..641cf74 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -3,7 +3,6 @@ package overlay import ( "fmt" "math" - "net" "net/netip" "runtime" "strconv" @@ -305,29 +304,3 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return routes, nil } - -func ipWithin(o *net.IPNet, i *net.IPNet) bool { - // Make sure o contains the lowest form of i - if !o.Contains(i.IP.Mask(i.Mask)) { - return false - } - - // Find the max ip in i - ip4 := i.IP.To4() - if ip4 == nil { - return false - } - - last := make(net.IP, len(ip4)) - copy(last, ip4) - for x := range ip4 { - last[x] |= ^i.Mask[x] - } - - // Make sure o contains the max - if !o.Contains(last) { - return false - } - - return true -} diff --git a/overlay/route_test.go b/overlay/route_test.go index 9a959a5..996095c 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -225,6 +225,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { // no mtu c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) @@ -318,7 +319,7 @@ func Test_makeRouteTree(t *testing.T) { ip, err = netip.ParseAddr("1.1.0.1") require.NoError(t, err) - r, ok = routeTree.Lookup(ip) + _, ok = routeTree.Lookup(ip) assert.False(t, ok) } diff --git a/pkclient/pkclient.go b/pkclient/pkclient.go index 7061de6..2d4fc19 100644 --- a/pkclient/pkclient.go +++ b/pkclient/pkclient.go @@ -1,8 +1,6 @@ package pkclient import ( - "crypto/ecdsa" - "crypto/x509" "fmt" "io" "strconv" @@ -50,27 +48,6 @@ func FromUrl(pkurl string) (*PKClient, error) { return New(module, uint(slotid), pin, id, label) } -func ecKeyToArray(key *ecdsa.PublicKey) []byte { - x := make([]byte, 32) - y := make([]byte, 32) - key.X.FillBytes(x) - key.Y.FillBytes(y) - return append([]byte{0x04}, append(x, y...)...) -} - -func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) { - e, err := x509.ParsePKIXPublicKey(d) - if err != nil { - return nil, err - } - switch t := e.(type) { - case *ecdsa.PublicKey: - return ecKeyToArray(e.(*ecdsa.PublicKey)), nil - default: - return nil, fmt.Errorf("unknown public key type: %T", t) - } -} - func (c *PKClient) Test() error { pub, err := c.GetPubKey() if err != nil { diff --git a/pkclient/pkclient_cgo.go b/pkclient/pkclient_cgo.go index a2ead55..3ad450a 100644 --- a/pkclient/pkclient_cgo.go +++ b/pkclient/pkclient_cgo.go @@ -3,6 +3,8 @@ package pkclient import ( + "crypto/ecdsa" + "crypto/x509" "encoding/asn1" "errors" "fmt" @@ -227,3 +229,24 @@ func (c *PKClient) GetPubKey() ([]byte, error) { return nil, fmt.Errorf("unknown public key length: %d", len(d)) } } + +func ecKeyToArray(key *ecdsa.PublicKey) []byte { + x := make([]byte, 32) + y := make([]byte, 32) + key.X.FillBytes(x) + key.Y.FillBytes(y) + return append([]byte{0x04}, append(x, y...)...) +} + +func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) { + e, err := x509.ParsePKIXPublicKey(d) + if err != nil { + return nil, err + } + switch t := e.(type) { + case *ecdsa.PublicKey: + return ecKeyToArray(e.(*ecdsa.PublicKey)), nil + default: + return nil, fmt.Errorf("unknown public key type: %T", t) + } +} diff --git a/pkclient/pkclient_stub.go b/pkclient/pkclient_stub.go index 36b0fc9..3be993f 100644 --- a/pkclient/pkclient_stub.go +++ b/pkclient/pkclient_stub.go @@ -7,10 +7,10 @@ import "errors" type PKClient struct { } -var notImplemented = errors.New("not implemented") +var errNotImplemented = errors.New("not implemented") func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { - return nil, notImplemented + return nil, errNotImplemented } func (c *PKClient) Close() error { @@ -18,13 +18,13 @@ func (c *PKClient) Close() error { } func (c *PKClient) SignASN1(data []byte) ([]byte, error) { - return nil, notImplemented + return nil, errNotImplemented } func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) { - return nil, notImplemented + return nil, errNotImplemented } func (c *PKClient) GetPubKey() ([]byte, error) { - return nil, notImplemented + return nil, errNotImplemented } diff --git a/remote_list.go b/remote_list.go index 6baed29..e9e6bec 100644 --- a/remote_list.go +++ b/remote_list.go @@ -263,9 +263,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort r.RLock() defer r.RUnlock() c := make([]netip.AddrPort, len(r.addrs)) - for i, v := range r.addrs { - c[i] = v - } + copy(c, r.addrs) return c } @@ -326,9 +324,7 @@ func (r *RemoteList) CopyCache() *CacheMap { } if mc.relay != nil { - for _, a := range mc.relay.relay { - c.Relay = append(c.Relay, a) - } + c.Relay = append(c.Relay, mc.relay.relay...) } } @@ -362,9 +358,7 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { defer r.RUnlock() c := make([]netip.AddrPort, len(r.badRemotes)) - for i, v := range r.badRemotes { - c[i] = v - } + copy(c, r.badRemotes) return c } @@ -569,9 +563,7 @@ func (r *RemoteList) unlockedCollect() { } if c.relay != nil { - for _, v := range c.relay.relay { - relays = append(relays, v) - } + relays = append(relays, c.relay.relay...) } } @@ -635,15 +627,15 @@ func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { a4 := a.Addr().Is4() b4 := b.Addr().Is4() switch { - case a4 == false && b4 == true: + case !a4 && b4: // If i is v6 and j is v4, i is less than j return true - case a4 == true && b4 == false: + case a4 && !b4: // If j is v6 and i is v4, i is not less than j return false - case a4 == true && b4 == true: + case a4 && b4: // i and j are both ipv4 aPrivate := a.Addr().IsPrivate() bPrivate := b.Addr().IsPrivate() @@ -691,7 +683,6 @@ func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { } r.addrs = r.addrs[:a+1] - return } // minInt returns the minimum integer of a or b diff --git a/ssh.go b/ssh.go index 9a26c29..4e546d3 100644 --- a/ssh.go +++ b/ssh.go @@ -527,11 +527,11 @@ func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error { return err } -func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { - return w.WriteLine(fmt.Sprintf("%s", ifce.version)) +func sshVersion(ifce *Interface, _ any, _ []string, w sshd.StringWriter) error { + return w.WriteLine(ifce.version) } -func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { +func sshQueryLighthouse(ifce *Interface, _ any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No vpn address was provided") } @@ -584,7 +584,7 @@ func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er hostInfo.ConnectionState, hostInfo, []byte{}, - make([]byte, 12, 12), + make([]byte, 12), make([]byte, mtu), ) } @@ -614,12 +614,12 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo != nil { - return w.WriteLine(fmt.Sprintf("Tunnel already exists")) + return w.WriteLine("Tunnel already exists") } hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr) if hostInfo != nil { - return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) + return w.WriteLine("Tunnel already handshaking") } var addr netip.AddrPort @@ -735,7 +735,7 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) } -func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } @@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) erro return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } -func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } @@ -822,10 +822,10 @@ func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) erro return w.WriteLine(cert.String()) } -func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { +func sshPrintRelays(ifce *Interface, fs any, _ []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { - w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) + w.WriteLine("sshPrintRelays failed to convert args type") return nil } diff --git a/sshd/server.go b/sshd/server.go index a8b60ba..00fdac5 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -23,7 +23,6 @@ type SSHServer struct { trustedCAs []ssh.PublicKey // List of available commands - helpCommand *Command commands *radix.Tree listener net.Listener @@ -43,7 +42,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { conns: make(map[int]*session), } - cc := ssh.CertChecker{ + cc := &ssh.CertChecker{ IsUserAuthority: func(auth ssh.PublicKey) bool { for _, ca := range s.trustedCAs { if bytes.Equal(ca.Marshal(), auth.Marshal()) { @@ -77,10 +76,11 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { }, } + s.certChecker = cc s.config = &ssh.ServerConfig{ PublicKeyCallback: cc.Authenticate, - ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), + ServerVersion: "SSH-2.0-Nebula???", } s.RegisterCommand(&Command{ diff --git a/sshd/session.go b/sshd/session.go index 87cc216..6e7a9cc 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -170,7 +170,6 @@ func (s *session) dispatchCommand(line string, w StringWriter) { } _ = execCommand(c, args[1:], w) - return } func (s *session) Close() { diff --git a/udp/conn.go b/udp/conn.go index 895b0df..965472a 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -30,15 +30,11 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return -} +func (NoopConn) ListenOut(_ EncReader) {} func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } -func (NoopConn) ReloadConfig(_ *config.C) { - return -} +func (NoopConn) ReloadConfig(_ *config.C) {} func (NoopConn) Close() error { return nil } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 06a4d53..be3e86e 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -33,7 +33,7 @@ func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, b if uc, ok := pc.(*net.UDPConn); ok { return &GenericConn{UDPConn: uc, l: l}, nil } - return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) + return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc) } func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { @@ -66,10 +66,6 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -type rawMessage struct { - Len uint32 -} - func (u *GenericConn) ListenOut(r EncReader) { buffer := make([]byte, MTU)