1818#include " Firewall.h"
1919
2020int Firewall::protocol_from_string_to_int (const std::string &proto) {
21- if (proto == " TCP" )
21+ if (proto == " TCP" || proto == " tcp " )
2222 return IPPROTO_TCP;
23- if (proto == " UDP" )
23+ if (proto == " UDP" || proto == " udp " )
2424 return IPPROTO_UDP;
25- if (proto == " ICMP" )
25+ if (proto == " ICMP" || proto == " icmp " )
2626 return IPPROTO_ICMP;
27- if (proto == " GRE" )
27+ if (proto == " GRE" || proto == " gre " )
2828 return IPPROTO_GRE;
2929 else
3030 throw std::runtime_error (" Protocol not supported." );
@@ -43,13 +43,13 @@ void Firewall::replaceAll(std::string &str, const std::string &from,
4343}
4444
4545int ChainRule::protocol_from_string_to_int (const std::string &proto) {
46- if (proto == " TCP" )
46+ if (proto == " TCP" || proto == " tcp " )
4747 return IPPROTO_TCP;
48- if (proto == " UDP" )
48+ if (proto == " UDP" || proto == " udp " )
4949 return IPPROTO_UDP;
50- if (proto == " ICMP" )
50+ if (proto == " ICMP" || proto == " icmp " )
5151 return IPPROTO_ICMP;
52- if (proto == " GRE" )
52+ if (proto == " GRE" || proto == " gre " )
5353 return IPPROTO_GRE;
5454
5555 throw std::runtime_error (" Protocol not supported." );
@@ -207,15 +207,14 @@ int ChainRule::ActionEnum_to_int(const ActionEnum &action) {
207207}
208208
209209static int ChainRuleConntrackEnum_to_int (const ConntrackstatusEnum &status) {
210- // 0 is reserved for "wildcard"
211210 if (status == ConntrackstatusEnum::NEW) {
212- return 1 ;
211+ return 0 ;
213212 } else if (status == ConntrackstatusEnum::ESTABLISHED) {
214- return 2 ;
213+ return 1 ;
215214 } else if (status == ConntrackstatusEnum::RELATED) {
216- return 3 ;
215+ return 2 ;
217216 } else if (status == ConntrackstatusEnum::INVALID) {
218- return 4 ;
217+ return 3 ;
219218 }
220219}
221220
@@ -482,46 +481,40 @@ bool Chain::portFromRulesToMap(
482481}
483482
484483bool Chain::conntrackFromRulesToMap (
485- std::map<uint8_t , std::vector<uint64_t >> &statusMap,
486- const std::vector<std::shared_ptr<ChainRule>> &rules) {
487- std::vector<uint16_t > dontCareRules;
484+ std::map<uint8_t , std::vector<uint64_t >> &statusMap,
485+ const std::vector<std::shared_ptr<ChainRule>> &rules) {
486+ std::vector<uint8_t > statesVector ({NEW, ESTABLISHED, RELATED, INVALID});
487+ uint32_t rule_id;
488+ uint8_t rule_state;
489+ bool conntrackRulePresent = false ;
488490
489- uint32_t ruleId;
490- uint8_t status;
491491 for (auto const &rule : rules) {
492492 try {
493- ruleId = rule->getId ();
494- status = ChainRuleConntrackEnum_to_int (rule->getConntrack ());
495-
496- } catch (std::runtime_error re) {
497- // Not set: don't care rule.
498- dontCareRules.push_back (ruleId);
499- continue ;
500- }
501-
502- auto it = statusMap.find (status);
503- if (it == statusMap.end ()) {
504- // First entry
505- std::vector<uint64_t > bitVector (
506- FROM_NRULES_TO_NELEMENTS (Firewall::maxRules));
507- SET_BIT (bitVector[ruleId / 63 ], ruleId % 63 );
508- statusMap.insert (
509- std::pair<uint8_t , std::vector<uint64_t >>(status, bitVector));
510- } else {
511- SET_BIT ((it->second )[ruleId / 63 ], ruleId % 63 );
493+ rule_state = ChainRuleConntrackEnum_to_int (rule->getConntrack ());
494+ conntrackRulePresent = true ;
495+ } catch (...) {
512496 }
513497 }
514- // Don't care rules are in all entries. Anyway, this loop is useless if there
515- // are no rules at all requiring matching on this field.
516- if (statusMap.size () != 0 && dontCareRules.size () != 0 ) {
498+
499+ if (!conntrackRulePresent)
500+ return false ;
501+
502+ for (uint8_t state : statesVector) {
517503 std::vector<uint64_t > bitVector (
518- FROM_NRULES_TO_NELEMENTS (Firewall::maxRules));
519- statusMap.insert (std::pair<uint8_t , std::vector<uint64_t >>(0 , bitVector));
520- for (auto const &ruleNumber : dontCareRules) {
521- for (auto &statusMapEntry : statusMap) {
522- SET_BIT ((statusMapEntry.second )[ruleNumber / 63 ], ruleNumber % 63 );
504+ FROM_NRULES_TO_NELEMENTS (Firewall::maxRules));
505+ for (auto const &rule : rules) {
506+ try {
507+ rule_id = rule->getId ();
508+ rule_state = ChainRuleConntrackEnum_to_int (rule->getConntrack ());
509+ if (rule_state == state)
510+ SET_BIT (bitVector[rule_id / 63 ], rule_id % 63 );
511+ } catch (std::runtime_error re) {
512+ // wildcard rule, set bit to 1
513+ SET_BIT (bitVector[rule_id / 63 ], rule_id % 63 );
523514 }
524515 }
516+ statusMap.insert (
517+ std::pair<uint8_t , std::vector<uint64_t >>(state, bitVector));
525518 }
526519
527520 return false ;
0 commit comments