mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-14 08:44:24 +01:00
Rework some things into packages (#489)
This commit is contained in:
139
main.go
139
main.go
@@ -8,14 +8,16 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
|
||||
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
|
||||
|
||||
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
||||
defer func() {
|
||||
@@ -31,7 +33,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
|
||||
// Print the config if in test, the exit comes later
|
||||
if configTest {
|
||||
b, err := yaml.Marshal(config.Settings)
|
||||
b, err := yaml.Marshal(c.Settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -40,33 +42,33 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
l.Println(string(b))
|
||||
}
|
||||
|
||||
err := configLogger(config)
|
||||
err := configLogger(l, c)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
||||
}
|
||||
|
||||
config.RegisterReloadCallback(func(c *Config) {
|
||||
err := configLogger(c)
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := configLogger(l, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to configure the logger")
|
||||
}
|
||||
})
|
||||
|
||||
caPool, err := loadCAFromConfig(l, config)
|
||||
caPool, err := loadCAFromConfig(l, c)
|
||||
if err != nil {
|
||||
//The errors coming out of loadCA are already nicely formatted
|
||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||
}
|
||||
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
|
||||
cs, err := NewCertStateFromConfig(config)
|
||||
cs, err := NewCertStateFromConfig(c)
|
||||
if err != nil {
|
||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
||||
}
|
||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||
|
||||
fw, err := NewFirewallFromConfig(l, cs.certificate, config)
|
||||
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||
}
|
||||
@@ -74,20 +76,20 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
|
||||
// TODO: make sure mask is 4 bytes
|
||||
tunCidr := cs.certificate.Details.Ips[0]
|
||||
routes, err := parseRoutes(config, tunCidr)
|
||||
routes, err := parseRoutes(c, tunCidr)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
||||
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
wireSSHReload(l, ssh, config)
|
||||
wireSSHReload(l, ssh, c)
|
||||
var sshStart func()
|
||||
if config.GetBool("sshd.enabled", false) {
|
||||
sshStart, err = configSSH(l, ssh, config)
|
||||
if c.GetBool("sshd.enabled", false) {
|
||||
sshStart, err = configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||
}
|
||||
@@ -101,7 +103,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
var routines int
|
||||
|
||||
// If `routines` is set, use that and ignore the specific values
|
||||
if routines = config.GetInt("routines", 0); routines != 0 {
|
||||
if routines = c.GetInt("routines", 0); routines != 0 {
|
||||
if routines < 1 {
|
||||
routines = 1
|
||||
}
|
||||
@@ -110,8 +112,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
}
|
||||
} else {
|
||||
// deprecated and undocumented
|
||||
tunQueues := config.GetInt("tun.routines", 1)
|
||||
udpQueues := config.GetInt("listen.routines", 1)
|
||||
tunQueues := c.GetInt("tun.routines", 1)
|
||||
udpQueues := c.GetInt("listen.routines", 1)
|
||||
if tunQueues > udpQueues {
|
||||
routines = tunQueues
|
||||
} else {
|
||||
@@ -125,8 +127,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
// EXPERIMENTAL
|
||||
// Intentionally not documented yet while we do more testing and determine
|
||||
// a good default value.
|
||||
conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
|
||||
if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
|
||||
conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
|
||||
if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") {
|
||||
// Use a different default if we are running with multiple routines
|
||||
conntrackCacheTimeout = 1 * time.Second
|
||||
}
|
||||
@@ -136,30 +138,30 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
|
||||
var tun Inside
|
||||
if !configTest {
|
||||
config.CatchHUP(ctx)
|
||||
c.CatchHUP(ctx)
|
||||
|
||||
switch {
|
||||
case config.GetBool("tun.disabled", false):
|
||||
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
|
||||
case c.GetBool("tun.disabled", false):
|
||||
tun = newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
case tunFd != nil:
|
||||
tun, err = newTunFromFd(
|
||||
l,
|
||||
*tunFd,
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
c.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
routes,
|
||||
unsafeRoutes,
|
||||
config.GetInt("tun.tx_queue", 500),
|
||||
c.GetInt("tun.tx_queue", 500),
|
||||
)
|
||||
default:
|
||||
tun, err = newTun(
|
||||
l,
|
||||
config.GetString("tun.dev", ""),
|
||||
c.GetString("tun.dev", ""),
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
c.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
routes,
|
||||
unsafeRoutes,
|
||||
config.GetInt("tun.tx_queue", 500),
|
||||
c.GetInt("tun.tx_queue", 500),
|
||||
routines > 1,
|
||||
)
|
||||
}
|
||||
@@ -176,16 +178,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
}()
|
||||
|
||||
// set up our UDP listener
|
||||
udpConns := make([]*udpConn, routines)
|
||||
port := config.GetInt("listen.port", 0)
|
||||
udpConns := make([]*udp.Conn, routines)
|
||||
port := c.GetInt("listen.port", 0)
|
||||
|
||||
if !configTest {
|
||||
for i := 0; i < routines; i++ {
|
||||
udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
|
||||
udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
}
|
||||
udpServer.reloadConfig(config)
|
||||
udpServer.ReloadConfig(c)
|
||||
udpConns[i] = udpServer
|
||||
|
||||
// If port is dynamic, discover it
|
||||
@@ -201,7 +203,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
|
||||
// Set up my internal host map
|
||||
var preferredRanges []*net.IPNet
|
||||
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
|
||||
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
|
||||
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
|
||||
if len(rawPreferredRanges) > 0 {
|
||||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
@@ -216,7 +218,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
// local_range was superseded by preferred_ranges. If it is still present,
|
||||
// merge the local_range setting into preferred_ranges. We will probably
|
||||
// deprecate local_range and remove in the future.
|
||||
rawLocalRange := config.GetString("local_range", "")
|
||||
rawLocalRange := c.GetString("local_range", "")
|
||||
if rawLocalRange != "" {
|
||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||
if err != nil {
|
||||
@@ -240,7 +242,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
|
||||
|
||||
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
||||
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
|
||||
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
|
||||
|
||||
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
|
||||
|
||||
@@ -249,26 +251,26 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
go hostMap.Promoter(config.GetInt("promoter.interval"))
|
||||
*/
|
||||
|
||||
punchy := NewPunchyFromConfig(config)
|
||||
punchy := NewPunchyFromConfig(c)
|
||||
if punchy.Punch && !configTest {
|
||||
l.Info("UDP hole punching enabled")
|
||||
go hostMap.Punchy(ctx, udpConns[0])
|
||||
}
|
||||
|
||||
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
|
||||
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||
|
||||
// fatal if am_lighthouse is enabled but we are using an ephemeral port
|
||||
if amLighthouse && (config.GetInt("listen.port", 0) == 0) {
|
||||
if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
|
||||
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
|
||||
}
|
||||
|
||||
// warn if am_lighthouse is enabled but upstream lighthouses exists
|
||||
rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{})
|
||||
rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{})
|
||||
if amLighthouse && len(rawLighthouseHosts) != 0 {
|
||||
l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
||||
}
|
||||
|
||||
lighthouseHosts := make([]uint32, len(rawLighthouseHosts))
|
||||
lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts))
|
||||
for i, host := range rawLighthouseHosts {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
@@ -277,7 +279,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
if !tunCidr.Contains(ip) {
|
||||
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
lighthouseHosts[i] = ip2int(ip)
|
||||
lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
|
||||
}
|
||||
|
||||
lightHouse := NewLightHouse(
|
||||
@@ -286,47 +288,48 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
tunCidr,
|
||||
lighthouseHosts,
|
||||
//TODO: change to a duration
|
||||
config.GetInt("lighthouse.interval", 10),
|
||||
c.GetInt("lighthouse.interval", 10),
|
||||
uint32(port),
|
||||
udpConns[0],
|
||||
punchy.Respond,
|
||||
punchy.Delay,
|
||||
config.GetBool("stats.lighthouse_metrics", false),
|
||||
c.GetBool("stats.lighthouse_metrics", false),
|
||||
)
|
||||
|
||||
remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
|
||||
remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||
|
||||
localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list")
|
||||
localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetLocalAllowList(localAllowList)
|
||||
|
||||
//TODO: Move all of this inside functions in lighthouse.go
|
||||
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
||||
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
||||
if !tunCidr.Contains(vpnIp) {
|
||||
for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
||||
ip := net.ParseIP(fmt.Sprintf("%v", k))
|
||||
vpnIp := iputil.Ip2VpnIp(ip)
|
||||
if !tunCidr.Contains(ip) {
|
||||
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
vals, ok := v.([]interface{})
|
||||
if ok {
|
||||
for _, v := range vals {
|
||||
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
|
||||
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
|
||||
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
||||
}
|
||||
} else {
|
||||
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
|
||||
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
|
||||
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,16 +339,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
}
|
||||
|
||||
var messageMetrics *MessageMetrics
|
||||
if config.GetBool("stats.message_metrics", false) {
|
||||
if c.GetBool("stats.message_metrics", false) {
|
||||
messageMetrics = newMessageMetrics()
|
||||
} else {
|
||||
messageMetrics = newMessageMetricsOnlyRecvError()
|
||||
}
|
||||
|
||||
handshakeConfig := HandshakeConfig{
|
||||
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
||||
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
|
||||
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
||||
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
||||
retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries),
|
||||
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
||||
|
||||
messageMetrics: messageMetrics,
|
||||
}
|
||||
@@ -358,36 +361,35 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
|
||||
|
||||
serveDns := false
|
||||
if config.GetBool("lighthouse.serve_dns", false) {
|
||||
if config.GetBool("lighthouse.am_lighthouse", false) {
|
||||
if c.GetBool("lighthouse.serve_dns", false) {
|
||||
if c.GetBool("lighthouse.am_lighthouse", false) {
|
||||
serveDns = true
|
||||
} else {
|
||||
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
||||
}
|
||||
}
|
||||
|
||||
checkInterval := config.GetInt("timers.connection_alive_interval", 5)
|
||||
pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10)
|
||||
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
|
||||
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
|
||||
ifConfig := &InterfaceConfig{
|
||||
HostMap: hostMap,
|
||||
Inside: tun,
|
||||
Outside: udpConns[0],
|
||||
certState: cs,
|
||||
Cipher: config.GetString("cipher", "aes"),
|
||||
Cipher: c.GetString("cipher", "aes"),
|
||||
Firewall: fw,
|
||||
ServeDns: serveDns,
|
||||
HandshakeManager: handshakeManager,
|
||||
lightHouse: lightHouse,
|
||||
checkInterval: checkInterval,
|
||||
pendingDeletionInterval: pendingDeletionInterval,
|
||||
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
|
||||
DropMulticast: config.GetBool("tun.drop_multicast", false),
|
||||
UDPBatchSize: config.GetInt("listen.batch", 64),
|
||||
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
||||
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
||||
routines: routines,
|
||||
MessageMetrics: messageMetrics,
|
||||
version: buildVersion,
|
||||
caPool: caPool,
|
||||
disconnectInvalid: config.GetBool("pki.disconnect_invalid", false),
|
||||
disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
|
||||
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
@@ -413,7 +415,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
// I don't want to make this initial commit too far-reaching though
|
||||
ifce.writers = udpConns
|
||||
|
||||
ifce.RegisterConfigChangeCallbacks(config)
|
||||
ifce.RegisterConfigChangeCallbacks(c)
|
||||
|
||||
go handshakeManager.Run(ctx, ifce)
|
||||
go lightHouse.LhUpdateWorker(ctx, ifce)
|
||||
@@ -421,7 +423,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
|
||||
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
|
||||
// a context so that they can exit when the context is Done.
|
||||
statsStart, err := startStats(l, config, buildVersion, configTest)
|
||||
statsStart, err := startStats(l, c, buildVersion, configTest)
|
||||
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||
}
|
||||
@@ -431,7 +434,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
}
|
||||
|
||||
//TODO: check if we _should_ be emitting stats
|
||||
go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10))
|
||||
go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
|
||||
|
||||
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||
|
||||
@@ -439,7 +442,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
||||
var dnsStart func()
|
||||
if amLighthouse && serveDns {
|
||||
l.Debugln("Starting dns server")
|
||||
dnsStart = dnsMain(l, hostMap, config)
|
||||
dnsStart = dnsMain(l, hostMap, c)
|
||||
}
|
||||
|
||||
return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
|
||||
|
||||
Reference in New Issue
Block a user