33use crate :: severity:: Severity ;
44use fancy_regex:: Regex ;
55use log:: warn;
6+ use std:: collections:: HashSet ;
67use std:: io;
7- use std:: path:: Path ;
8+ use std:: path:: { Path , PathBuf } ;
89
910/// Represents a secret detection rule.
1011pub struct Rule {
@@ -15,12 +16,36 @@ pub struct Rule {
1516
1617/// Loads rules from a TOML file.
1718pub fn load_rules < P : AsRef < Path > > ( path : P ) -> io:: Result < Vec < Rule > > {
19+ let mut visited = HashSet :: new ( ) ;
20+ load_rules_inner ( path. as_ref ( ) , & mut visited)
21+ }
22+
23+ /// Recursively loads rules, tracking visited files to detect circular includes.
24+ fn load_rules_inner ( path : & Path , visited : & mut HashSet < PathBuf > ) -> io:: Result < Vec < Rule > > {
25+ let canonical = path. canonicalize ( ) ?;
26+ if !visited. insert ( canonical) {
27+ return Err ( io:: Error :: new (
28+ io:: ErrorKind :: InvalidData ,
29+ format ! ( "Circular include detected: {}" , path. display( ) ) ,
30+ ) ) ;
31+ }
32+
33+ let base_dir = path. parent ( ) . unwrap_or ( Path :: new ( "." ) ) ;
1834 let content = std:: fs:: read_to_string ( path) ?;
1935 let toml_value: toml:: Value =
2036 toml:: from_str ( & content) . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ?;
2137
2238 let mut rules = Vec :: new ( ) ;
2339
40+ if let Some ( includes) = toml_value. get ( "includes" ) . and_then ( |v| v. as_array ( ) ) {
41+ for include in includes {
42+ if let Some ( include_path) = include. as_str ( ) {
43+ let full_path = base_dir. join ( include_path) ;
44+ rules. extend ( load_rules_inner ( & full_path, visited) ?) ;
45+ }
46+ }
47+ }
48+
2449 if let Some ( rules_array) = toml_value. get ( "rules" ) . and_then ( |v| v. as_array ( ) ) {
2550 for rule in rules_array {
2651 let description = match rule. get ( "description" ) . and_then ( |v| v. as_str ( ) ) {
@@ -80,4 +105,64 @@ mod tests {
80105 assert_eq ! ( rules[ 0 ] . severity, Severity :: High ) ;
81106 assert ! ( rules[ 0 ] . regex. is_match( "tok_123" ) . unwrap( ) ) ;
82107 }
108+
109+ #[ test]
110+ fn test_load_rules_with_includes ( ) {
111+ let base_toml = r#"
112+ [[rules]]
113+ description = "Base Rule"
114+ regex = "base_[0-9]+"
115+ severity = "low"
116+ "# ;
117+
118+ let child_toml = r#"
119+ includes = ["base_rules.toml"]
120+
121+ [[rules]]
122+ description = "Child Rule"
123+ regex = "child_[0-9]+"
124+ severity = "high"
125+ "# ;
126+
127+ let tmp_dir = std:: env:: temp_dir ( ) . join ( "lsa_test_includes" ) ;
128+ std:: fs:: create_dir_all ( & tmp_dir) . unwrap ( ) ;
129+
130+ std:: fs:: write ( tmp_dir. join ( "base_rules.toml" ) , base_toml) . unwrap ( ) ;
131+ std:: fs:: write ( tmp_dir. join ( "child_rules.toml" ) , child_toml) . unwrap ( ) ;
132+
133+ let rules = load_rules ( tmp_dir. join ( "child_rules.toml" ) ) . unwrap ( ) ;
134+ assert_eq ! ( rules. len( ) , 2 ) ;
135+ assert_eq ! ( rules[ 0 ] . description, "Base Rule" ) ;
136+ assert_eq ! ( rules[ 1 ] . description, "Child Rule" ) ;
137+ }
138+
139+ #[ test]
140+ fn test_load_rules_circular_include ( ) {
141+ let a_toml = r#"
142+ includes = ["b.toml"]
143+
144+ [[rules]]
145+ description = "Rule A"
146+ regex = "a_[0-9]+"
147+ severity = "low"
148+ "# ;
149+
150+ let b_toml = r#"
151+ includes = ["a.toml"]
152+
153+ [[rules]]
154+ description = "Rule B"
155+ regex = "b_[0-9]+"
156+ severity = "low"
157+ "# ;
158+
159+ let tmp_dir = std:: env:: temp_dir ( ) . join ( "lsa_test_circular" ) ;
160+ std:: fs:: create_dir_all ( & tmp_dir) . unwrap ( ) ;
161+
162+ std:: fs:: write ( tmp_dir. join ( "a.toml" ) , a_toml) . unwrap ( ) ;
163+ std:: fs:: write ( tmp_dir. join ( "b.toml" ) , b_toml) . unwrap ( ) ;
164+
165+ let result = load_rules ( tmp_dir. join ( "a.toml" ) ) ;
166+ assert ! ( result. is_err( ) ) ;
167+ }
83168}
0 commit comments