mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-11 05:53:58 +01:00
Merge remote-tracking branch 'origin/master' into mutex-debug
This commit is contained in:
commit
a83f0ca470
11
README.md
11
README.md
@ -118,6 +118,17 @@ To build nebula for a specific platform (ex, Windows):
|
||||
|
||||
See the [Makefile](Makefile) for more details on build targets
|
||||
|
||||
## Curve P256 and BoringCrypto
|
||||
|
||||
The default curve used for cryptographic handshakes and signatures is Curve25519. This is the recommended setting for most users. If your deployment has certain compliance requirements, you have the option of creating your CA using `nebula-cert ca -curve P256` to use NIST Curve P256. The CA will then sign certificates using ECDSA P256, and any hosts using these certificates will use P256 for ECDH handshakes.
|
||||
|
||||
In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
|
||||
|
||||
make bin-boringcrypto
|
||||
make release-boringcrypto
|
||||
|
||||
This is not the recommended default deployment, but may be useful based on your compliance requirements.
|
||||
|
||||
## Credits
|
||||
|
||||
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.
|
||||
|
||||
12
SECURITY.md
Normal file
12
SECURITY.md
Normal file
@ -0,0 +1,12 @@
|
||||
Security Policy
|
||||
===============
|
||||
|
||||
Reporting a Vulnerability
|
||||
-------------------------
|
||||
|
||||
If you believe you have found a security vulnerability with Nebula, please let
|
||||
us know right away. We will investigate all reports and do our best to quickly
|
||||
fix valid issues.
|
||||
|
||||
You can submit your report on [HackerOne](https://hackerone.com/slack) and our
|
||||
security team will respond as soon as possible.
|
||||
@ -13,8 +13,14 @@ type Node struct {
|
||||
value interface{}
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
CIDR *net.IPNet
|
||||
Value *interface{}
|
||||
}
|
||||
|
||||
type Tree4 struct {
|
||||
root *Node
|
||||
list []entry
|
||||
}
|
||||
|
||||
const (
|
||||
@ -24,6 +30,7 @@ const (
|
||||
func NewTree4() *Tree4 {
|
||||
tree := new(Tree4)
|
||||
tree.root = &Node{}
|
||||
tree.list = []entry{}
|
||||
return tree
|
||||
}
|
||||
|
||||
@ -53,6 +60,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||
|
||||
// We already have this range so update the value
|
||||
if next != nil {
|
||||
addCIDR := cidr.String()
|
||||
for i, v := range tree.list {
|
||||
if addCIDR == v.CIDR.String() {
|
||||
tree.list = append(tree.list[:i], tree.list[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
|
||||
node.value = val
|
||||
return
|
||||
}
|
||||
@ -74,9 +90,10 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||
|
||||
// Final node marks our cidr, set the value
|
||||
node.value = val
|
||||
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
|
||||
}
|
||||
|
||||
// Finds the first match, which may be the least specific
|
||||
// Contains finds the first match, which may be the least specific
|
||||
func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
@ -99,7 +116,7 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
|
||||
return value
|
||||
}
|
||||
|
||||
// Finds the most specific match
|
||||
// MostSpecificContains finds the most specific match
|
||||
func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
@ -121,7 +138,7 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
|
||||
return value
|
||||
}
|
||||
|
||||
// Finds the most specific match
|
||||
// Match finds the most specific match
|
||||
func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
@ -143,3 +160,8 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// List will return all CIDRs and their current values. Do not modify the contents!
|
||||
func (tree *Tree4) List() []entry {
|
||||
return tree.list
|
||||
}
|
||||
|
||||
@ -8,6 +8,20 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCIDRTree_List(t *testing.T) {
|
||||
tree := NewTree4()
|
||||
tree.AddCIDR(Parse("1.0.0.0/16"), "1")
|
||||
tree.AddCIDR(Parse("1.0.0.0/8"), "2")
|
||||
tree.AddCIDR(Parse("1.0.0.0/16"), "3")
|
||||
tree.AddCIDR(Parse("1.0.0.0/16"), "4")
|
||||
list := tree.List()
|
||||
assert.Len(t, list, 2)
|
||||
assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
|
||||
assert.Equal(t, "2", *list[0].Value)
|
||||
assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
|
||||
assert.Equal(t, "4", *list[1].Value)
|
||||
}
|
||||
|
||||
func TestCIDRTree_Contains(t *testing.T) {
|
||||
tree := NewTree4()
|
||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||
|
||||
@ -47,7 +47,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
Signature: []byte{1, 2, 1, 2, 1, 3},
|
||||
}
|
||||
|
||||
remotes := NewRemoteList()
|
||||
remotes := NewRemoteList(nil)
|
||||
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
|
||||
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
|
||||
hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
|
||||
|
||||
@ -223,6 +223,10 @@ tun:
|
||||
# metric: 100
|
||||
# install: true
|
||||
|
||||
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
|
||||
# in nebula configuration files. Default false, not reloadable.
|
||||
#use_system_route_table: false
|
||||
|
||||
# TODO
|
||||
# Configure logging level
|
||||
logging:
|
||||
@ -301,7 +305,8 @@ firewall:
|
||||
# host: `any` or a literal hostname, ie `test-host`
|
||||
# group: `any` or a literal group name, ie `default-group`
|
||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||
# cidr: a CIDR, `0.0.0.0/0` is any.
|
||||
# cidr: a remote CIDR, `0.0.0.0/0` is any.
|
||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
|
||||
# ca_name: An issuing CA name
|
||||
# ca_sha: An issuing CA shasum
|
||||
|
||||
|
||||
101
firewall.go
101
firewall.go
@ -25,7 +25,7 @@ const tcpACK = 0x10
|
||||
const tcpFIN = 0x01
|
||||
|
||||
type FirewallInterface interface {
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
@ -106,11 +106,12 @@ type FirewallCA struct {
|
||||
}
|
||||
|
||||
type FirewallRule struct {
|
||||
// Any makes Hosts, Groups, and CIDR irrelevant
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *cidr.Tree4
|
||||
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *cidr.Tree4
|
||||
LocalCIDR *cidr.Tree4
|
||||
}
|
||||
|
||||
// Even though ports are uint16, int32 maps are faster for lookup
|
||||
@ -218,18 +219,22 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
|
||||
}
|
||||
|
||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
||||
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
|
||||
// https://github.com/golang/go/issues/14131
|
||||
sIp := ""
|
||||
if ip != nil {
|
||||
sIp = ip.String()
|
||||
}
|
||||
lIp := ""
|
||||
if localIp != nil {
|
||||
lIp = localIp.String()
|
||||
}
|
||||
|
||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||
ruleString := fmt.Sprintf(
|
||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
|
||||
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
|
||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
||||
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
|
||||
)
|
||||
f.rules += ruleString + "\n"
|
||||
|
||||
@ -237,7 +242,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
if !incoming {
|
||||
direction = "outgoing"
|
||||
}
|
||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
|
||||
Info("Firewall rule added")
|
||||
|
||||
var (
|
||||
@ -264,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
return fmt.Errorf("unknown protocol %v", proto)
|
||||
}
|
||||
|
||||
return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha)
|
||||
return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
|
||||
}
|
||||
|
||||
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
||||
@ -302,8 +307,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
|
||||
}
|
||||
|
||||
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" {
|
||||
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i)
|
||||
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
|
||||
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
|
||||
}
|
||||
|
||||
if len(r.Groups) > 0 {
|
||||
@ -355,7 +360,15 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
}
|
||||
}
|
||||
|
||||
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha)
|
||||
var localCidr *net.IPNet
|
||||
if r.LocalCidr != "" {
|
||||
_, localCidr, err = net.ParseCIDR(r.LocalCidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
|
||||
}
|
||||
@ -595,7 +608,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
|
||||
return false
|
||||
}
|
||||
|
||||
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
||||
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
||||
if startPort > endPort {
|
||||
return fmt.Errorf("start port was lower than end port")
|
||||
}
|
||||
@ -608,7 +621,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
||||
}
|
||||
}
|
||||
|
||||
if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil {
|
||||
if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -639,12 +652,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
|
||||
return fp[firewall.PortAny].match(p, c, caPool)
|
||||
}
|
||||
|
||||
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
|
||||
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
||||
fr := func() *FirewallRule {
|
||||
return &FirewallRule{
|
||||
Hosts: make(map[string]struct{}),
|
||||
Groups: make([][]string, 0),
|
||||
CIDR: cidr.NewTree4(),
|
||||
Hosts: make(map[string]struct{}),
|
||||
Groups: make([][]string, 0),
|
||||
CIDR: cidr.NewTree4(),
|
||||
LocalCIDR: cidr.NewTree4(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -653,14 +667,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
|
||||
fc.Any = fr()
|
||||
}
|
||||
|
||||
return fc.Any.addRule(groups, host, ip)
|
||||
return fc.Any.addRule(groups, host, ip, localIp)
|
||||
}
|
||||
|
||||
if caSha != "" {
|
||||
if _, ok := fc.CAShas[caSha]; !ok {
|
||||
fc.CAShas[caSha] = fr()
|
||||
}
|
||||
err := fc.CAShas[caSha].addRule(groups, host, ip)
|
||||
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -670,7 +684,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
|
||||
if _, ok := fc.CANames[caName]; !ok {
|
||||
fc.CANames[caName] = fr()
|
||||
}
|
||||
err := fc.CANames[caName].addRule(groups, host, ip)
|
||||
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -702,17 +716,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
|
||||
return fc.CANames[s.Details.Name].match(p, c)
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
|
||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
|
||||
if fr.Any {
|
||||
return nil
|
||||
}
|
||||
|
||||
if fr.isAny(groups, host, ip) {
|
||||
if fr.isAny(groups, host, ip, localIp) {
|
||||
fr.Any = true
|
||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||
fr.Groups = make([][]string, 0)
|
||||
fr.Hosts = make(map[string]struct{})
|
||||
fr.CIDR = cidr.NewTree4()
|
||||
fr.LocalCIDR = cidr.NewTree4()
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
fr.Groups = append(fr.Groups, groups)
|
||||
@ -725,13 +740,17 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
|
||||
if ip != nil {
|
||||
fr.CIDR.AddCIDR(ip, struct{}{})
|
||||
}
|
||||
|
||||
if localIp != nil {
|
||||
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
|
||||
if len(groups) == 0 && host == "" && ip == nil {
|
||||
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
|
||||
if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -749,6 +768,10 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
|
||||
return true
|
||||
}
|
||||
|
||||
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@ -790,20 +813,25 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||
return true
|
||||
}
|
||||
|
||||
if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// No host, group, or cidr matched, bye bye
|
||||
return false
|
||||
}
|
||||
|
||||
type rule struct {
|
||||
Port string
|
||||
Code string
|
||||
Proto string
|
||||
Host string
|
||||
Group string
|
||||
Groups []string
|
||||
Cidr string
|
||||
CAName string
|
||||
CASha string
|
||||
Port string
|
||||
Code string
|
||||
Proto string
|
||||
Host string
|
||||
Group string
|
||||
Groups []string
|
||||
Cidr string
|
||||
LocalCidr string
|
||||
CAName string
|
||||
CASha string
|
||||
}
|
||||
|
||||
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
|
||||
@ -827,6 +855,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
|
||||
r.Proto = toString("proto", m)
|
||||
r.Host = toString("host", m)
|
||||
r.Cidr = toString("cidr", m)
|
||||
r.LocalCidr = toString("local_cidr", m)
|
||||
r.CAName = toString("ca_name", m)
|
||||
r.CASha = toString("ca_sha", m)
|
||||
|
||||
|
||||
147
firewall_test.go
147
firewall_test.go
@ -69,67 +69,75 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
|
||||
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
|
||||
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
|
||||
// An empty rule is any
|
||||
assert.True(t, fw.InRules.TCP[1].Any.Any)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
|
||||
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
|
||||
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
|
||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
|
||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||
|
||||
// Set any and clear fields
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
|
||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
||||
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
||||
|
||||
// run twice just to make sure
|
||||
//TODO: these ANY rules should clear the CA firewall portion
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
|
||||
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
|
||||
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
|
||||
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop(t *testing.T) {
|
||||
@ -169,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
@ -188,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
@ -219,11 +227,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
}
|
||||
|
||||
_, n, _ := net.ParseCIDR("172.1.1.1/32")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
b.Run("fail on proto", func(b *testing.B) {
|
||||
@ -291,7 +299,20 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
}
|
||||
})
|
||||
|
||||
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
|
||||
b.Run("pass on local ip", func(b *testing.B) {
|
||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "good-host",
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
|
||||
}
|
||||
})
|
||||
|
||||
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
|
||||
|
||||
b.Run("pass on ip with any port", func(b *testing.B) {
|
||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
@ -305,6 +326,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on local ip with any port", func(b *testing.B) {
|
||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "good-host",
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFirewall_Drop2(t *testing.T) {
|
||||
@ -356,7 +390,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
h1.CreateRemoteCIDR(&c1)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// h1/c1 lacks the proper groups
|
||||
@ -438,8 +472,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
h3.CreateRemoteCIDR(&c3)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// c1 should pass because host match
|
||||
@ -489,7 +523,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
@ -502,7 +536,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
@ -511,7 +545,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
@ -653,7 +687,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||
|
||||
// Test code/port error
|
||||
conf = config.NewC(l)
|
||||
@ -677,6 +711,12 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
|
||||
|
||||
// Test local_cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
|
||||
|
||||
// Test both group and groups
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||
@ -691,63 +731,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
mf := &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
|
||||
|
||||
// Test adding udp rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
|
||||
|
||||
// Test adding icmp rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
|
||||
|
||||
// Test adding any rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
|
||||
|
||||
// Test adding rule with cidr
|
||||
cidr := &net.IPNet{net.ParseIP("10.0.0.0").To4(), net.IPv4Mask(255, 0, 0, 0)}
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
|
||||
|
||||
// Test adding rule with local_cidr
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_name
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
|
||||
|
||||
// Test single group
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
|
||||
|
||||
// Test single groups
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
|
||||
|
||||
// Test multiple AND groups
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
|
||||
|
||||
// Test Add error
|
||||
conf = config.NewC(l)
|
||||
@ -892,6 +947,7 @@ type addRuleCall struct {
|
||||
groups []string
|
||||
host string
|
||||
ip *net.IPNet
|
||||
localIp *net.IPNet
|
||||
caName string
|
||||
caSha string
|
||||
}
|
||||
@ -901,7 +957,7 @@ type mockFirewall struct {
|
||||
nextCallReturn error
|
||||
}
|
||||
|
||||
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
||||
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
||||
mf.lastCall = addRuleCall{
|
||||
incoming: incoming,
|
||||
proto: proto,
|
||||
@ -910,6 +966,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
||||
groups: groups,
|
||||
host: host,
|
||||
ip: ip,
|
||||
localIp: localIp,
|
||||
caName: caName,
|
||||
caSha: caSha,
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
assert.False(t, initCalled)
|
||||
assert.Same(t, i, i2)
|
||||
|
||||
i.remotes = NewRemoteList()
|
||||
i.remotes = NewRemoteList(nil)
|
||||
i.HandshakeReady = true
|
||||
|
||||
// Adding something to pending should not affect the main hostmap
|
||||
|
||||
175
lighthouse.go
175
lighthouse.go
@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -33,6 +34,7 @@ type netIpAndPort struct {
|
||||
type LightHouse struct {
|
||||
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
|
||||
sync.RWMutex //Because we concurrently read and write to our maps
|
||||
ctx context.Context
|
||||
amLighthouse bool
|
||||
myVpnIp iputil.VpnIp
|
||||
myVpnZeros iputil.VpnIp
|
||||
@ -82,7 +84,7 @@ type LightHouse struct {
|
||||
|
||||
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
||||
// addrMap should be nil unless this is during a config reload
|
||||
func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
||||
if amLighthouse && nebulaPort == 0 {
|
||||
@ -100,6 +102,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet,
|
||||
|
||||
ones, _ := myVpnNet.Mask.Size()
|
||||
h := LightHouse{
|
||||
ctx: ctx,
|
||||
amLighthouse: amLighthouse,
|
||||
myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP),
|
||||
myVpnZeros: iputil.VpnIp(32 - ones),
|
||||
@ -258,7 +261,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
}
|
||||
|
||||
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
|
||||
if initial || c.HasChanged("static_host_map") {
|
||||
if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
|
||||
staticList := make(map[iputil.VpnIp]struct{})
|
||||
err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
|
||||
if err != nil {
|
||||
@ -268,9 +271,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
lh.staticList.Store(&staticList)
|
||||
if !initial {
|
||||
//TODO: we should remove any remote list entries for static hosts that were removed/modified?
|
||||
lh.l.Info("static_host_map has changed")
|
||||
if c.HasChanged("static_host_map") {
|
||||
lh.l.Info("static_host_map has changed")
|
||||
}
|
||||
if c.HasChanged("static_map.cadence") {
|
||||
lh.l.Info("static_map.cadence has changed")
|
||||
}
|
||||
if c.HasChanged("static_map.network") {
|
||||
lh.l.Info("static_map.network has changed")
|
||||
}
|
||||
if c.HasChanged("static_map.lookup_timeout") {
|
||||
lh.l.Info("static_map.lookup_timeout has changed")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("lighthouse.hosts") {
|
||||
@ -344,7 +357,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStaticMapCadence(c *config.C) (time.Duration, error) {
|
||||
cadence := c.GetString("static_map.cadence", "30s")
|
||||
d, err := time.ParseDuration(cadence)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) {
|
||||
lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms")
|
||||
d, err := time.ParseDuration(lookupTimeout)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func getStaticMapNetwork(c *config.C) (string, error) {
|
||||
network := c.GetString("static_map.network", "ip4")
|
||||
if network != "ip" && network != "ip4" && network != "ip6" {
|
||||
return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6")
|
||||
}
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
|
||||
d, err := getStaticMapCadence(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
network, err := getStaticMapNetwork(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lookup_timeout, err := getStaticMapLookupTimeout(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
|
||||
i := 0
|
||||
|
||||
@ -360,21 +414,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
|
||||
|
||||
vpnIp := iputil.Ip2VpnIp(rip)
|
||||
vals, ok := v.([]interface{})
|
||||
if ok {
|
||||
for _, v := range vals {
|
||||
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||
if err != nil {
|
||||
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
|
||||
}
|
||||
lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
|
||||
}
|
||||
if !ok {
|
||||
vals = []interface{}{v}
|
||||
}
|
||||
remoteAddrs := []string{}
|
||||
for _, v := range vals {
|
||||
remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
} else {
|
||||
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||
if err != nil {
|
||||
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
|
||||
}
|
||||
lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
|
||||
err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i++
|
||||
}
|
||||
@ -482,30 +532,47 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
|
||||
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
|
||||
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
|
||||
// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
|
||||
func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) {
|
||||
func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
|
||||
lh.Lock()
|
||||
am := lh.unlockedGetRemoteList(vpnIp)
|
||||
am.Lock()
|
||||
defer am.Unlock()
|
||||
ctx := lh.ctx
|
||||
lh.Unlock()
|
||||
|
||||
if ipv4 := toAddr.IP.To4(); ipv4 != nil {
|
||||
to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
|
||||
if !lh.unlockedShouldAddV4(vpnIp, to) {
|
||||
return
|
||||
}
|
||||
am.unlockedPrependV4(lh.myVpnIp, to)
|
||||
hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() {
|
||||
// This callback runs whenever the DNS hostname resolver finds a different set of IP's
|
||||
// in its resolution for hostnames.
|
||||
am.Lock()
|
||||
defer am.Unlock()
|
||||
am.shouldRebuild = true
|
||||
})
|
||||
if err != nil {
|
||||
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
|
||||
}
|
||||
am.unlockedSetHostnamesResults(hr)
|
||||
|
||||
} else {
|
||||
to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
|
||||
if !lh.unlockedShouldAddV6(vpnIp, to) {
|
||||
return
|
||||
for _, addrPort := range hr.GetIPs() {
|
||||
|
||||
switch {
|
||||
case addrPort.Addr().Is4():
|
||||
to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
|
||||
if !lh.unlockedShouldAddV4(vpnIp, to) {
|
||||
continue
|
||||
}
|
||||
am.unlockedPrependV4(lh.myVpnIp, to)
|
||||
case addrPort.Addr().Is6():
|
||||
to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
|
||||
if !lh.unlockedShouldAddV6(vpnIp, to) {
|
||||
continue
|
||||
}
|
||||
am.unlockedPrependV6(lh.myVpnIp, to)
|
||||
}
|
||||
am.unlockedPrependV6(lh.myVpnIp, to)
|
||||
}
|
||||
|
||||
// Mark it as static in the caller provided map
|
||||
staticList[vpnIp] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addCalculatedRemotes adds any calculated remotes based on the
|
||||
@ -545,12 +612,42 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
|
||||
func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
|
||||
am, ok := lh.addrMap[vpnIp]
|
||||
if !ok {
|
||||
am = NewRemoteList()
|
||||
am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
|
||||
lh.addrMap[vpnIp] = am
|
||||
}
|
||||
return am
|
||||
}
|
||||
|
||||
func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
|
||||
switch {
|
||||
case to.Is4():
|
||||
ipBytes := to.As4()
|
||||
ip := iputil.Ip2VpnIp(ipBytes[:])
|
||||
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||
}
|
||||
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
|
||||
return false
|
||||
}
|
||||
case to.Is6():
|
||||
ipBytes := to.As16()
|
||||
|
||||
hi := binary.BigEndian.Uint64(ipBytes[:8])
|
||||
lo := binary.BigEndian.Uint64(ipBytes[8:])
|
||||
allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
|
||||
}
|
||||
|
||||
// We don't check our vpn network here because nebula does not support ipv6 on the inside
|
||||
if !allow {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// unlockedShouldAddV4 checks if to is allowed by our allow list
|
||||
func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
|
||||
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
|
||||
@ -609,6 +706,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
|
||||
return &ipp
|
||||
}
|
||||
|
||||
func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
|
||||
v4Addr := ip.As4()
|
||||
return &Ip4AndPort{
|
||||
Ip: binary.BigEndian.Uint32(v4Addr[:]),
|
||||
Port: uint32(port),
|
||||
}
|
||||
}
|
||||
|
||||
func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
|
||||
return &Ip6AndPort{
|
||||
Hi: binary.BigEndian.Uint64(ip[:8]),
|
||||
@ -617,6 +722,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
|
||||
}
|
||||
}
|
||||
|
||||
func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
|
||||
ip6Addr := ip.As16()
|
||||
return &Ip6AndPort{
|
||||
Hi: binary.BigEndian.Uint64(ip6Addr[:8]),
|
||||
Lo: binary.BigEndian.Uint64(ip6Addr[8:]),
|
||||
Port: uint32(port),
|
||||
}
|
||||
}
|
||||
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
|
||||
ip := ipp.Ip
|
||||
return udp.NewAddr(
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
@ -53,14 +54,14 @@ func Test_lhStaticMapping(t *testing.T) {
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||
_, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
|
||||
_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
lh2 := "10.128.0.3"
|
||||
c = config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
|
||||
_, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
|
||||
_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
||||
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||
}
|
||||
|
||||
@ -69,14 +70,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
|
||||
|
||||
c := config.NewC(l)
|
||||
lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
||||
if !assert.NoError(b, err) {
|
||||
b.Fatal()
|
||||
}
|
||||
|
||||
hAddr := udp.NewAddrFromString("4.5.6.7:12345")
|
||||
hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
|
||||
lh.addrMap[3] = NewRemoteList()
|
||||
lh.addrMap[3] = NewRemoteList(nil)
|
||||
lh.addrMap[3].unlockedSetV4(
|
||||
3,
|
||||
3,
|
||||
@ -89,7 +90,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
|
||||
rAddr := udp.NewAddrFromString("1.2.2.3:12345")
|
||||
rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
|
||||
lh.addrMap[2] = NewRemoteList()
|
||||
lh.addrMap[2] = NewRemoteList(nil)
|
||||
lh.addrMap[2].unlockedSetV4(
|
||||
3,
|
||||
3,
|
||||
@ -162,7 +163,7 @@ func TestLighthouse_Memory(t *testing.T) {
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||
lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
lhh := lh.NewRequestHandler()
|
||||
|
||||
@ -238,7 +239,7 @@ func TestLighthouse_reload(t *testing.T) {
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||
lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}
|
||||
|
||||
2
main.go
2
main.go
@ -226,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
*/
|
||||
|
||||
punchy := NewPunchyFromConfig(l, c)
|
||||
lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy)
|
||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
|
||||
switch {
|
||||
case errors.As(err, &util.ContextualError{}):
|
||||
return nil, err
|
||||
|
||||
@ -35,6 +35,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
|
||||
c.GetInt("tun.mtu", DefaultMTU),
|
||||
routes,
|
||||
c.GetInt("tun.tx_queue", 500),
|
||||
c.GetBool("tun.use_system_route_table", false),
|
||||
)
|
||||
|
||||
default:
|
||||
@ -46,6 +47,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
|
||||
routes,
|
||||
c.GetInt("tun.tx_queue", 500),
|
||||
routines > 1,
|
||||
c.GetBool("tun.use_system_route_table", false),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,7 +22,7 @@ type tun struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -41,7 +41,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
|
||||
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ type ifreqMTU struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
|
||||
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -170,7 +170,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
|
||||
@ -38,11 +38,11 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -23,11 +23,11 @@ type tun struct {
|
||||
routeTree *cidr.Tree4
|
||||
}
|
||||
|
||||
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
|
||||
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -4,11 +4,13 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -26,9 +28,13 @@ type tun struct {
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4
|
||||
l *logrus.Logger
|
||||
|
||||
Routes []Route
|
||||
routeTree atomic.Pointer[cidr.Tree4]
|
||||
routeChan chan struct{}
|
||||
useSystemRoutes bool
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
@ -63,7 +69,7 @@ type ifreqQLEN struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -71,7 +77,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
|
||||
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
return &tun{
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
Device: "tun0",
|
||||
@ -79,12 +85,14 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
|
||||
DefaultMTU: defaultMTU,
|
||||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
useSystemRoutes: useSystemRoutes,
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
t.routeTree.Store(routeTree)
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) {
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -119,7 +127,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tun{
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
Device: name,
|
||||
@ -128,9 +136,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
|
||||
DefaultMTU: defaultMTU,
|
||||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
useSystemRoutes: useSystemRoutes,
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
t.routeTree.Store(routeTree)
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
@ -152,7 +162,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.MostSpecificContains(ip)
|
||||
r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
}
|
||||
@ -183,16 +193,20 @@ func (t *tun) Write(b []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t tun) deviceBytes() (o [16]byte) {
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (t tun) Activate() error {
|
||||
func (t *tun) Activate() error {
|
||||
devName := t.deviceBytes()
|
||||
|
||||
if t.useSystemRoutes {
|
||||
t.watchRoutes()
|
||||
}
|
||||
|
||||
var addr, mask [4]byte
|
||||
|
||||
copy(addr[:], t.cidr.IP.To4())
|
||||
@ -318,7 +332,7 @@ func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t tun) advMSS(r Route) int {
|
||||
func (t *tun) advMSS(r Route) int {
|
||||
mtu := r.MTU
|
||||
if r.MTU == 0 {
|
||||
mtu = t.DefaultMTU
|
||||
@ -330,3 +344,83 @@ func (t tun) advMSS(r Route) int {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (t *tun) watchRoutes() {
|
||||
rch := make(chan netlink.RouteUpdate)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
|
||||
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
|
||||
return
|
||||
}
|
||||
|
||||
t.routeChan = doneChan
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case r := <-rch:
|
||||
t.updateRoutes(r)
|
||||
case <-doneChan:
|
||||
// netlink.RouteSubscriber will close the rch for us
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
if r.Gw == nil {
|
||||
// Not a gateway route, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
|
||||
return
|
||||
}
|
||||
|
||||
if !t.cidr.Contains(r.Gw) {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
return
|
||||
}
|
||||
|
||||
if x := r.Dst.IP.To4(); x == nil {
|
||||
// Nebula only handles ipv4 on the overlay currently
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
|
||||
return
|
||||
}
|
||||
|
||||
newTree := cidr.NewTree4()
|
||||
if r.Type == unix.RTM_NEWROUTE {
|
||||
for _, oldR := range t.routeTree.Load().List() {
|
||||
newTree.AddCIDR(oldR.CIDR, oldR.Value)
|
||||
}
|
||||
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
|
||||
newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
|
||||
|
||||
} else {
|
||||
gw := iputil.Ip2VpnIp(r.Gw)
|
||||
for _, oldR := range t.routeTree.Load().List() {
|
||||
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
|
||||
// This is the record to delete
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
||||
continue
|
||||
}
|
||||
|
||||
newTree.AddCIDR(oldR.CIDR, oldR.Value)
|
||||
}
|
||||
}
|
||||
|
||||
t.routeTree.Store(newTree)
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.routeChan != nil {
|
||||
close(t.routeChan)
|
||||
}
|
||||
|
||||
if t.ReadWriteCloser != nil {
|
||||
t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -7,19 +7,19 @@ import "testing"
|
||||
|
||||
var runAdvMSSTests = []struct {
|
||||
name string
|
||||
tun tun
|
||||
tun *tun
|
||||
r Route
|
||||
expected int
|
||||
}{
|
||||
// Standard case, default MTU is the device max MTU
|
||||
{"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
|
||||
{"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
|
||||
{"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
|
||||
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
|
||||
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
|
||||
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
|
||||
|
||||
// Case where we have a route MTU set higher than the default
|
||||
{"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
|
||||
{"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
|
||||
{"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
|
||||
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
|
||||
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
|
||||
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
|
||||
}
|
||||
|
||||
func TestTunAdvMSS(t *testing.T) {
|
||||
|
||||
@ -25,7 +25,7 @@ type TestTun struct {
|
||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -42,7 +42,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) {
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||
}
|
||||
|
||||
|
||||
@ -14,11 +14,11 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) {
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) {
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) {
|
||||
useWintun := true
|
||||
if err := checkWinTunExists(); err != nil {
|
||||
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
|
||||
|
||||
170
remote_list.go
170
remote_list.go
@ -2,10 +2,16 @@ package nebula
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
@ -55,6 +61,132 @@ type cacheV6 struct {
|
||||
reported []*Ip6AndPort
|
||||
}
|
||||
|
||||
type hostnamePort struct {
|
||||
name string
|
||||
port uint16
|
||||
}
|
||||
|
||||
type hostnamesResults struct {
|
||||
hostnames []hostnamePort
|
||||
network string
|
||||
lookupTimeout time.Duration
|
||||
stop chan struct{}
|
||||
l *logrus.Logger
|
||||
ips atomic.Pointer[map[netip.AddrPort]struct{}]
|
||||
}
|
||||
|
||||
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
|
||||
r := &hostnamesResults{
|
||||
hostnames: make([]hostnamePort, len(hostPorts)),
|
||||
network: network,
|
||||
lookupTimeout: timeout,
|
||||
stop: make(chan (struct{})),
|
||||
l: l,
|
||||
}
|
||||
|
||||
// Fastrack IP addresses to ensure they're immediately available for use.
|
||||
// DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine.
|
||||
performBackgroundLookup := false
|
||||
ips := map[netip.AddrPort]struct{}{}
|
||||
for idx, hostPort := range hostPorts {
|
||||
|
||||
rIp, sPort, err := net.SplitHostPort(hostPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iPort, err := strconv.Atoi(sPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)}
|
||||
addr, err := netip.ParseAddr(rIp)
|
||||
if err != nil {
|
||||
// This address is a hostname, not an IP address
|
||||
performBackgroundLookup = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Save the IP address immediately
|
||||
ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{}
|
||||
}
|
||||
r.ips.Store(&ips)
|
||||
|
||||
// Time for the DNS lookup goroutine
|
||||
if performBackgroundLookup {
|
||||
ticker := time.NewTicker(d)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
netipAddrs := map[netip.AddrPort]struct{}{}
|
||||
for _, hostPort := range r.hostnames {
|
||||
timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout)
|
||||
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
|
||||
timeoutCancel()
|
||||
if err != nil {
|
||||
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
|
||||
continue
|
||||
}
|
||||
for _, a := range addrs {
|
||||
netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
|
||||
}
|
||||
}
|
||||
origSet := r.ips.Load()
|
||||
different := false
|
||||
for a := range *origSet {
|
||||
if _, ok := netipAddrs[a]; !ok {
|
||||
different = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !different {
|
||||
for a := range netipAddrs {
|
||||
if _, ok := (*origSet)[a]; !ok {
|
||||
different = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if different {
|
||||
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
|
||||
r.ips.Store(&netipAddrs)
|
||||
onUpdate()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (hr *hostnamesResults) Cancel() {
|
||||
if hr != nil {
|
||||
hr.stop <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
|
||||
var retSlice []netip.AddrPort
|
||||
if hr != nil {
|
||||
p := hr.ips.Load()
|
||||
if p != nil {
|
||||
for k := range *p {
|
||||
retSlice = append(retSlice, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
return retSlice
|
||||
}
|
||||
|
||||
// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
|
||||
// It serves as a local cache of query replies, host update notifications, and locally learned addresses
|
||||
type RemoteList struct {
|
||||
@ -72,6 +204,9 @@ type RemoteList struct {
|
||||
// For learned addresses, this is the vpnIp that sent the packet
|
||||
cache map[iputil.VpnIp]*cache
|
||||
|
||||
hr *hostnamesResults
|
||||
shouldAdd func(netip.Addr) bool
|
||||
|
||||
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
||||
// They should not be tried again during a handshake
|
||||
badRemotes []*udp.Addr
|
||||
@ -81,14 +216,21 @@ type RemoteList struct {
|
||||
}
|
||||
|
||||
// NewRemoteList creates a new empty RemoteList
|
||||
func NewRemoteList() *RemoteList {
|
||||
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
|
||||
return &RemoteList{
|
||||
addrs: make([]*udp.Addr, 0),
|
||||
relays: make([]*iputil.VpnIp, 0),
|
||||
cache: make(map[iputil.VpnIp]*cache),
|
||||
addrs: make([]*udp.Addr, 0),
|
||||
relays: make([]*iputil.VpnIp, 0),
|
||||
cache: make(map[iputil.VpnIp]*cache),
|
||||
shouldAdd: shouldAdd,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
|
||||
// Cancel any existing hostnamesResults DNS goroutine to release resources
|
||||
r.hr.Cancel()
|
||||
r.hr = hr
|
||||
}
|
||||
|
||||
// Len locks and reports the size of the deduplicated address list
|
||||
// The deduplication work may need to occur here, so you must pass preferredRanges
|
||||
func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
|
||||
@ -437,6 +579,26 @@ func (r *RemoteList) unlockedCollect() {
|
||||
}
|
||||
}
|
||||
|
||||
dnsAddrs := r.hr.GetIPs()
|
||||
for _, addr := range dnsAddrs {
|
||||
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
||||
switch {
|
||||
case addr.Addr().Is4():
|
||||
v4 := addr.Addr().As4()
|
||||
addrs = append(addrs, &udp.Addr{
|
||||
IP: v4[:],
|
||||
Port: addr.Port(),
|
||||
})
|
||||
case addr.Addr().Is6():
|
||||
v6 := addr.Addr().As16()
|
||||
addrs = append(addrs, &udp.Addr{
|
||||
IP: v6[:],
|
||||
Port: addr.Port(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.addrs = addrs
|
||||
r.relays = relays
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestRemoteList_Rebuild(t *testing.T) {
|
||||
rl := NewRemoteList()
|
||||
rl := NewRemoteList(nil)
|
||||
rl.unlockedSetV4(
|
||||
0,
|
||||
0,
|
||||
@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
||||
}
|
||||
|
||||
func BenchmarkFullRebuild(b *testing.B) {
|
||||
rl := NewRemoteList()
|
||||
rl := NewRemoteList(nil)
|
||||
rl.unlockedSetV4(
|
||||
0,
|
||||
0,
|
||||
@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSortRebuild(b *testing.B) {
|
||||
rl := NewRemoteList()
|
||||
rl := NewRemoteList(nil)
|
||||
rl.unlockedSetV4(
|
||||
0,
|
||||
0,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user