From 353ad1f27193b92d3bb7d6696cac5239fa1728e5 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Fri, 13 Feb 2026 11:10:40 -0600 Subject: [PATCH] firewall: icmp no longer requires a port spec (#1609) --- examples/config.yml | 2 +- firewall.go | 138 +++++++++++++++++++++++++------------------- firewall_test.go | 20 +++++-- 3 files changed, 95 insertions(+), 65 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index f81baab6..1f9dc2a4 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -382,8 +382,8 @@ firewall: # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr) # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). - # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # proto: `any`, `tcp`, `udp`, or `icmp` + # a port specification is ignored if proto is `icmp` # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass diff --git a/firewall.go b/firewall.go index 45dc0691..72119e0e 100644 --- a/firewall.go +++ b/firewall.go @@ -249,20 +249,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew // AddRule properly creates the in memory rule structure for a firewall table. func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { - // We need this rule string because we generate a hash. Removing this will break firewall reload. - ruleString := fmt.Sprintf( - "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, - ) - f.rules += ruleString + "\n" - - direction := "incoming" - if !incoming { - direction = "outgoing" - } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). - Info("Firewall rule added") - var ( ft *FirewallTable fp firewallPort @@ -280,6 +266,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort case firewall.ProtoUDP: fp = ft.UDP case firewall.ProtoICMP, firewall.ProtoICMPv6: + //ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided + if startPort != firewall.PortAny { + f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule") + } + startPort = firewall.PortAny + endPort = firewall.PortAny fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -287,6 +279,20 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort return fmt.Errorf("unknown protocol %v", proto) } + // We need this rule string because we generate a hash. Removing this will break firewall reload. + ruleString := fmt.Sprintf( + "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", + incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, + ) + f.rules += ruleString + "\n" + + direction := "incoming" + if !incoming { + direction = "outgoing" + } + f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). + Info("Firewall rule added") + return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) } @@ -349,24 +355,31 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw sPort = r.Port } - startPort, endPort, err := parsePort(sPort) - if err != nil { - return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) - } - var proto uint8 + var startPort, endPort int32 switch r.Proto { case "any": proto = firewall.ProtoAny + startPort, endPort, err = parsePort(sPort) case "tcp": proto = firewall.ProtoTCP + startPort, endPort, err = parsePort(sPort) case "udp": proto = firewall.ProtoUDP + startPort, endPort, err = parsePort(sPort) case "icmp": proto = firewall.ProtoICMP + startPort = firewall.PortAny + endPort = firewall.PortAny + if sPort != "" { + l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule") + } default: return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } + if err != nil { + return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) + } if r.Cidr != "" && r.Cidr != "any" { _, err = netip.ParsePrefix(r.Cidr) @@ -660,6 +673,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer return false } + // this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match + if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 { + // port numbers are re-used for connection tracking of ICMP, + // but we don't want to actually filter on them. + return fp[firewall.PortAny].match(p, c, caPool) + } + var port int32 if p.Fragment { @@ -1018,54 +1038,56 @@ func (r *rule) sanity() error { } } + if r.Code != "" { + return fmt.Errorf("code specified as [%s]. Support for 'code' will be dropped in a future release, as it has never been functional", r.Code) + } + //todo alert on cidr-any return nil } -func parsePort(s string) (startPort, endPort int32, err error) { +func parsePort(s string) (int32, int32, error) { + var err error + const notAPort int32 = -2 if s == "any" { - startPort = firewall.PortAny - endPort = firewall.PortAny - - } else if s == "fragment" { - startPort = firewall.PortFragment - endPort = firewall.PortFragment - - } else if strings.Contains(s, `-`) { - sPorts := strings.SplitN(s, `-`, 2) - sPorts[0] = strings.Trim(sPorts[0], " ") - sPorts[1] = strings.Trim(sPorts[1], " ") - - if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { - return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) - } - - rStartPort, err := strconv.Atoi(sPorts[0]) - if err != nil { - return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) - } - - rEndPort, err := strconv.Atoi(sPorts[1]) - if err != nil { - return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) - } - - startPort = int32(rStartPort) - endPort = int32(rEndPort) - - if startPort == firewall.PortAny { - endPort = firewall.PortAny - } - - } else { + return firewall.PortAny, firewall.PortAny, nil + } + if s == "fragment" { + return firewall.PortFragment, firewall.PortFragment, nil + } + if !strings.Contains(s, `-`) { rPort, err := strconv.Atoi(s) if err != nil { - return 0, 0, fmt.Errorf("was not a number; `%s`", s) + return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s) } - startPort = int32(rPort) - endPort = startPort + return int32(rPort), int32(rPort), nil } - return + sPorts := strings.SplitN(s, `-`, 2) + for i := range sPorts { + sPorts[i] = strings.Trim(sPorts[i], " ") + } + if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { + return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) + } + + rStartPort, err := strconv.Atoi(sPorts[0]) + if err != nil { + return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) + } + + rEndPort, err := strconv.Atoi(sPorts[1]) + if err != nil { + return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) + } + + startPort := int32(rStartPort) + endPort := int32(rEndPort) + + if startPort == firewall.PortAny { + endPort = firewall.PortAny + } + + return startPort, endPort, nil } diff --git a/firewall_test.go b/firewall_test.go index 1df62a81..934a90a4 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -87,9 +87,10 @@ func TestFirewall_AddRule(t *testing.T) { fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) - assert.Nil(t, fw.InRules.ICMP[1].Any.Any) - assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) - assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") + //no matter what port is given for icmp, it should end up as "any" + assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any) + assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups) + assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) @@ -919,11 +920,11 @@ func TestNewFirewallFromConfig(t *testing.T) { // Test code/port error conf = config.NewC(l) - conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") - conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") @@ -973,7 +974,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) + + // Test adding icmp rule no port + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding any rule conf = config.NewC(l)