@@ -6,12 +6,13 @@ use maxminddb::PathElement;
66
77use crate :: config:: AccessControlConfig ;
88
9- pub struct GeoIp {
9+ pub struct IpFilter {
1010 geo_reader : Option < maxminddb:: Reader < Vec < u8 > > > ,
1111 blocked_regions : HashSet < String > ,
12+ blocked_ips : Vec < ipnet:: IpNet > ,
1213}
1314
14- impl GeoIp {
15+ impl IpFilter {
1516 pub async fn from_config (
1617 config : & AccessControlConfig ,
1718 storage_dir : & Path ,
@@ -42,11 +43,30 @@ impl GeoIp {
4243
4344 let blocked_regions = config. blocked_regions . iter ( ) . cloned ( ) . collect ( ) ;
4445
45- Ok ( Self { geo_reader, blocked_regions } )
46+ let blocked_ips = config
47+ . blocked_ips
48+ . iter ( )
49+ . map ( |s| {
50+ s. parse :: < ipnet:: IpNet > ( ) . or_else ( |_| {
51+ // Accept bare IP addresses without CIDR prefix length
52+ Ok ( ipnet:: IpNet :: from ( s. parse :: < IpAddr > ( ) ?) )
53+ } )
54+ } )
55+ . collect :: < Result < Vec < _ > , anyhow:: Error > > ( ) ?;
56+
57+ Ok ( Self { geo_reader, blocked_regions, blocked_ips } )
4658 }
4759
48- /// Returns `true` if the IP is allowed. Fail-open on lookup errors.
60+ /// Returns `true` if the IP is allowed. Fail-open on GeoIP lookup errors.
4961 pub fn check_ip ( & self , ip : IpAddr ) -> bool {
62+ if self . blocked_ips . iter ( ) . any ( |net| net. contains ( & ip) ) {
63+ return false ;
64+ }
65+
66+ self . check_geoip ( ip)
67+ }
68+
69+ fn check_geoip ( & self , ip : IpAddr ) -> bool {
5070 let reader = match & self . geo_reader {
5171 Some ( r) => r,
5272 None => return true ,
@@ -165,14 +185,19 @@ mod tests {
165185
166186 #[ test]
167187 fn check_ip_allows_when_no_geo_reader ( ) {
168- let ac = GeoIp { geo_reader : None , blocked_regions : HashSet :: new ( ) } ;
188+ let ac =
189+ IpFilter { geo_reader : None , blocked_regions : HashSet :: new ( ) , blocked_ips : vec ! [ ] } ;
169190 assert ! ( ac. check_ip( "1.2.3.4" . parse( ) . unwrap( ) ) ) ;
170191 }
171192
172193 #[ test]
173194 fn check_ip_allows_when_no_blocked_regions ( ) {
174195 let reader = test_geo_reader ( ) ;
175- let ac = GeoIp { geo_reader : Some ( reader) , blocked_regions : HashSet :: new ( ) } ;
196+ let ac = IpFilter {
197+ geo_reader : Some ( reader) ,
198+ blocked_regions : HashSet :: new ( ) ,
199+ blocked_ips : vec ! [ ] ,
200+ } ;
176201 assert ! ( ac. check_ip( "2.125.160.216" . parse( ) . unwrap( ) ) ) ;
177202 }
178203
@@ -181,7 +206,7 @@ mod tests {
181206 let reader = test_geo_reader ( ) ;
182207 // 2.125.160.216 is GB in the test database
183208 let blocked_regions: HashSet < String > = [ "GB" ] . iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
184- let ac = GeoIp { geo_reader : Some ( reader) , blocked_regions } ;
209+ let ac = IpFilter { geo_reader : Some ( reader) , blocked_regions, blocked_ips : vec ! [ ] } ;
185210 assert ! ( !ac. check_ip( "2.125.160.216" . parse( ) . unwrap( ) ) ) ;
186211 }
187212
@@ -190,19 +215,52 @@ mod tests {
190215 let reader = test_geo_reader ( ) ;
191216 // 2.125.160.216 is GB in the test database
192217 let blocked_regions: HashSet < String > = [ "US" ] . iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
193- let ac = GeoIp { geo_reader : Some ( reader) , blocked_regions } ;
218+ let ac = IpFilter { geo_reader : Some ( reader) , blocked_regions, blocked_ips : vec ! [ ] } ;
194219 assert ! ( ac. check_ip( "2.125.160.216" . parse( ) . unwrap( ) ) ) ;
195220 }
196221
197222 #[ test]
198223 fn check_ip_fail_open_on_unknown_ip ( ) {
199224 let reader = test_geo_reader ( ) ;
200225 let blocked_regions: HashSet < String > = [ "US" ] . iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
201- let ac = GeoIp { geo_reader : Some ( reader) , blocked_regions } ;
226+ let ac = IpFilter { geo_reader : Some ( reader) , blocked_regions, blocked_ips : vec ! [ ] } ;
202227 // 127.0.0.1 won't be in test DB
203228 assert ! ( ac. check_ip( "127.0.0.1" . parse( ) . unwrap( ) ) ) ;
204229 }
205230
231+ #[ test]
232+ fn blocked_ips_blocks_exact_ipv4 ( ) {
233+ let blocked_ips = vec ! [ "192.0.2.1/32" . parse( ) . unwrap( ) ] ;
234+ let ac = IpFilter { geo_reader : None , blocked_regions : HashSet :: new ( ) , blocked_ips } ;
235+ assert ! ( !ac. check_ip( "192.0.2.1" . parse( ) . unwrap( ) ) ) ;
236+ assert ! ( ac. check_ip( "192.0.2.2" . parse( ) . unwrap( ) ) ) ;
237+ }
238+
239+ #[ test]
240+ fn blocked_ips_blocks_exact_ipv6 ( ) {
241+ let blocked_ips = vec ! [ "2001:db8::1/128" . parse( ) . unwrap( ) ] ;
242+ let ac = IpFilter { geo_reader : None , blocked_regions : HashSet :: new ( ) , blocked_ips } ;
243+ assert ! ( !ac. check_ip( "2001:db8::1" . parse( ) . unwrap( ) ) ) ;
244+ assert ! ( ac. check_ip( "2001:db8::2" . parse( ) . unwrap( ) ) ) ;
245+ }
246+
247+ #[ test]
248+ fn blocked_ips_blocks_cidr_range ( ) {
249+ let blocked_ips = vec ! [ "198.51.100.0/24" . parse( ) . unwrap( ) ] ;
250+ let ac = IpFilter { geo_reader : None , blocked_regions : HashSet :: new ( ) , blocked_ips } ;
251+ assert ! ( !ac. check_ip( "198.51.100.0" . parse( ) . unwrap( ) ) ) ;
252+ assert ! ( !ac. check_ip( "198.51.100.255" . parse( ) . unwrap( ) ) ) ;
253+ assert ! ( ac. check_ip( "198.51.101.0" . parse( ) . unwrap( ) ) ) ;
254+ }
255+
256+ #[ test]
257+ fn empty_blocked_ips_allows_all ( ) {
258+ let ac =
259+ IpFilter { geo_reader : None , blocked_regions : HashSet :: new ( ) , blocked_ips : vec ! [ ] } ;
260+ assert ! ( ac. check_ip( "192.0.2.1" . parse( ) . unwrap( ) ) ) ;
261+ assert ! ( ac. check_ip( "2001:db8::1" . parse( ) . unwrap( ) ) ) ;
262+ }
263+
206264 #[ test]
207265 fn year_month_conversion_handles_leap_day ( ) {
208266 // 2024-02-29 00:00:00 UTC
0 commit comments