diff --git a/firewall.go b/firewall.go index 45dc069..d900036 100644 --- a/firewall.go +++ b/firewall.go @@ -23,16 +23,15 @@ import ( ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error + AddRule(forward, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } type conn struct { Expires time.Time // Time when this conntrack entry will expire - // record why the original connection passed the firewall, so we can re-validate - // after ruleset changes. Note, rulesVersion is a uint16 so that these two - // fields pack for free after the uint32 above + // record why the original connection passed the firewall, so we can re-validate after ruleset changes. incoming bool + forward bool rulesVersion uint16 } @@ -40,8 +39,10 @@ type conn struct { type Firewall struct { Conntrack *FirewallConntrack - InRules *FirewallTable - OutRules *FirewallTable + InRules *FirewallTable + OutRules *FirewallTable + ForwardInRules *FirewallTable + ForwardOutRules *FirewallTable InSendReject bool OutSendReject bool @@ -54,7 +55,7 @@ type Firewall struct { // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. // The vpn addresses are a full bit match while the unsafe networks only match the prefix - routableNetworks *bart.Lite + routableNetworks *bart.Table[NetworkType] // assignedNetworks is a list of vpn networks assigned to us in the certificate. assignedNetworks []netip.Prefix @@ -149,17 +150,16 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D tmax = defaultTimeout } - routableNetworks := new(bart.Lite) + routableNetworks := new(bart.Table[NetworkType]) var assignedNetworks []netip.Prefix for _, network := range c.Networks() { - nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - routableNetworks.Insert(nprefix) + routableNetworks.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), NetworkTypeVPN) assignedNetworks = append(assignedNetworks, network) } hasUnsafeNetworks := false for _, n := range c.UnsafeNetworks() { - routableNetworks.Insert(n) + routableNetworks.Insert(n, NetworkTypeUnsafe) hasUnsafeNetworks = true } @@ -170,6 +170,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D }, InRules: newFirewallTable(), OutRules: newFirewallTable(), + ForwardInRules: newFirewallTable(), + ForwardOutRules: newFirewallTable(), TCPTimeout: tcpTimeout, UDPTimeout: UDPTimeout, DefaultTimeout: defaultTimeout, @@ -212,6 +214,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) + //TODO: do we also need firewall.forward_inbound_action and firewall.forward_outbound_action? inboundAction := c.GetString("firewall.inbound_action", "drop") switch inboundAction { case "reject": @@ -234,12 +237,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew fw.OutSendReject = false } - err := AddFirewallRulesFromConfig(l, false, c, fw) + err := AddFirewallRulesFromConfig(l, false, false, c, fw) if err != nil { return nil, err } - err = AddFirewallRulesFromConfig(l, true, c, fw) + err = AddFirewallRulesFromConfig(l, true, false, c, fw) + if err != nil { + return nil, err + } + + err = AddFirewallRulesFromConfig(l, false, true, c, fw) + if err != nil { + return nil, err + } + + err = AddFirewallRulesFromConfig(l, true, true, c, fw) if err != nil { return nil, err } @@ -248,11 +261,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew } // AddRule properly creates the in memory rule structure for a firewall table. -func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { +func (f *Firewall) AddRule(forward, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { // We need this rule string because we generate a hash. Removing this will break firewall reload. ruleString := fmt.Sprintf( - "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, + "forward: %v, incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", + forward, incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, ) f.rules += ruleString + "\n" @@ -260,8 +273,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort if !incoming { direction = "outgoing" } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). - Info("Firewall rule added") + + fields := m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha} + if forward { + fields["forward"] = true + } + f.l.WithField("firewallRule", fields).Info("Firewall rule added") var ( ft *FirewallTable @@ -269,9 +286,18 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort ) if incoming { - ft = f.InRules + if forward { + ft = f.ForwardInRules + } else { + ft = f.InRules + } + } else { - ft = f.OutRules + if forward { + ft = f.ForwardOutRules + } else { + ft = f.OutRules + } } switch proto { @@ -308,12 +334,21 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } -func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { +func AddFirewallRulesFromConfig(l *logrus.Logger, forward, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { - table = "firewall.inbound" + if forward { + table = "firewall.forward_inbound" + } else { + table = "firewall.inbound" + } + } else { - table = "firewall.outbound" + if forward { + table = "firewall.forward_outbound" + } else { + table = "firewall.outbound" + } } r := c.Get(table) @@ -386,7 +421,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw l.Warnf("%s rule #%v; %s", table, i, warning) } - err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) + err = fw.AddRule(forward, inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) if err != nil { return fmt.Errorf("%s rule #%v; `%s`", table, i, err) } @@ -409,6 +444,9 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return nil } + var remoteNetworkType NetworkType + var ok bool + // Make sure remote address matches nebula certificate, and determine how to treat it if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks @@ -416,13 +454,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } + remoteNetworkType = NetworkTypeVPN } else { - nwType, ok := h.networks.Lookup(fp.RemoteAddr) + remoteNetworkType, ok = h.networks.Lookup(fp.RemoteAddr) if !ok { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } - switch nwType { + switch remoteNetworkType { case NetworkTypeVPN: break // nothing special case NetworkTypeVPNPeer: @@ -437,14 +476,27 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to be handling this local ip address - if !f.routableNetworks.Contains(fp.LocalAddr) { + localNetworkType, ok := f.routableNetworks.Lookup(fp.LocalAddr) + if !ok { f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } - table := f.OutRules + useForward := remoteNetworkType == NetworkTypeUnsafe || localNetworkType == NetworkTypeUnsafe + + var table *FirewallTable if incoming { - table = f.InRules + if useForward { + table = f.ForwardInRules + } else { + table = f.InRules + } + } else { + if useForward { + table = f.ForwardOutRules + } else { + table = f.OutRules + } } // We now know which firewall table to check against @@ -454,12 +506,13 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // We always want to conntrack since it is a faster operation - f.addConn(fp, incoming) + f.addConn(fp, useForward, incoming) return nil } func (f *Firewall) metrics(incoming bool) firewallMetrics { + //TODO: need forward metrics too if incoming { return f.incomingMetrics } else { @@ -499,7 +552,6 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, } c, ok := conntrack.Conns[fp] - if !ok { conntrack.Unlock() return false @@ -508,9 +560,19 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, if c.rulesVersion != f.rulesVersion { // This conntrack entry was for an older rule set, validate // it still passes with the current rule set - table := f.OutRules + var table *FirewallTable if c.incoming { - table = f.InRules + if c.forward { + table = f.ForwardInRules + } else { + table = f.InRules + } + } else { + if c.forward { + table = f.ForwardOutRules + } else { + table = f.OutRules + } } // We now know which firewall table to check against @@ -519,6 +581,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, h.logger(f.l). WithField("fwPacket", fp). WithField("incoming", c.incoming). + WithField("forward", c.forward). WithField("rulesVersion", f.rulesVersion). WithField("oldRulesVersion", c.rulesVersion). Debugln("dropping old conntrack entry, does not match new ruleset") @@ -532,6 +595,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, h.logger(f.l). WithField("fwPacket", fp). WithField("incoming", c.incoming). + WithField("forward", c.forward). WithField("rulesVersion", f.rulesVersion). WithField("oldRulesVersion", c.rulesVersion). Debugln("keeping old conntrack entry, does match new ruleset") @@ -558,7 +622,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, return true } -func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { +func (f *Firewall) addConn(fp firewall.Packet, forward, incoming bool) { var timeout time.Duration c := &conn{} @@ -581,6 +645,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { // Record which rulesVersion allowed this connection, so we can retest after // firewall reload c.incoming = incoming + c.forward = forward c.rulesVersion = f.rulesVersion c.Expires = time.Now().Add(timeout) conntrack.Conns[fp] = c @@ -937,6 +1002,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { r.Code = toString("code", m) r.Proto = toString("proto", m) r.Host = toString("host", m) + //TODO: create an alias to remote_cidr and deprecate cidr? r.Cidr = toString("cidr", m) r.LocalCidr = toString("local_cidr", m) r.CAName = toString("ca_name", m) diff --git a/firewall_test.go b/firewall_test.go index 1df62a8..3ac903b 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -73,65 +73,65 @@ func TestFirewall_AddRule(t *testing.T) { ti6, err := netip.ParsePrefix("fd12::34/128") require.NoError(t, err) - require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") require.NoError(t, err) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) assert.True(t, table.Any) @@ -142,7 +142,7 @@ func TestFirewall_AddRule(t *testing.T) { anyIp6, err := netip.ParsePrefix("::/0") require.NoError(t, err) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) assert.True(t, table.Any) @@ -150,29 +150,29 @@ func TestFirewall_AddRule(t *testing.T) { assert.False(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) + require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) - require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) + require.Error(t, fw.AddRule(false, true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) + require.Error(t, fw.AddRule(false, true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -208,7 +208,7 @@ func TestFirewall_Drop(t *testing.T) { h.buildNetworks(myVpnNetworksTable, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound @@ -227,28 +227,28 @@ func TestFirewall_Drop(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -287,7 +287,7 @@ func TestFirewall_DropV6(t *testing.T) { h.buildNetworks(myVpnNetworksTable, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound @@ -306,28 +306,28 @@ func TestFirewall_DropV6(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -532,7 +532,7 @@ func TestFirewall_Drop2(t *testing.T) { h1.buildNetworks(myVpnNetworksTable, c1.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups @@ -612,8 +612,8 @@ func TestFirewall_Drop3(t *testing.T) { h3.buildNetworks(myVpnNetworksTable, c3.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -627,7 +627,7 @@ func TestFirewall_Drop3(t *testing.T) { // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } @@ -665,7 +665,7 @@ func TestFirewall_Drop3V6(t *testing.T) { // Test a remote address match fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) cp := cert.NewCAPool() - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -704,7 +704,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { h.buildNetworks(myVpnNetworksTable, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound @@ -717,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -726,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -765,7 +765,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) cp := cert.NewCAPool() // Packet spoofed by `c1`. Note that the remote addr is not a valid one. @@ -958,28 +958,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding rule with cidr @@ -987,14 +987,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall) // Test adding rule with cidr ipv6 @@ -1002,75 +1002,75 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) // Test adding rule with any cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall) // Test adding rule with junk cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} - require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") + require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with local_cidr ipv6 conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) // Test adding rule with any local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) // Test adding rule with junk local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} - require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") + require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) // Test Add error @@ -1078,7 +1078,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} - require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") + require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; `test error`") } func TestFirewall_convertRule(t *testing.T) { @@ -1251,7 +1251,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { myVpnNetworksTable.Insert(prefix) } fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) return testsetup{ c: c, @@ -1332,7 +1332,7 @@ func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") tc.err = ErrNoMatchingRule tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off - require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", "")) + require.NoError(t, unsafeSetup.fw.AddRule(true, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", "")) tc.err = nil tc.Test(t, unsafeSetup.fw) //should pass }) @@ -1356,7 +1356,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(forward, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto,