/* Copyright (C) 2022 Open Information Security Foundation * * You can copy, redistribute or modify this Program under the terms of * the GNU General Public License version 2 as published by the Free * Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * version 2 along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA * 02110-1301, USA. */ use nom7::branch::alt; use nom7::bytes::complete::{is_a, tag, tag_no_case, take_while}; use nom7::character::complete::digit1; use nom7::combinator::{all_consuming, map_opt, opt, value, verify}; use nom7::error::{make_error, ErrorKind}; use nom7::Err; use nom7::IResult; use std::ffi::CStr; #[derive(PartialEq, Eq, Clone, Debug)] #[repr(u8)] pub enum DetectUintMode { DetectUintModeEqual, DetectUintModeLt, DetectUintModeLte, DetectUintModeGt, DetectUintModeGte, DetectUintModeRange, DetectUintModeNe, } #[derive(Debug)] #[repr(C)] pub struct DetectUintData { pub arg1: T, pub arg2: T, pub mode: DetectUintMode, } pub trait DetectIntType: std::str::FromStr + std::cmp::PartialOrd + num::PrimInt + num::Bounded + num::ToPrimitive + num::FromPrimitive { } impl DetectIntType for T where T: std::str::FromStr + std::cmp::PartialOrd + num::PrimInt + num::Bounded + num::ToPrimitive + num::FromPrimitive { } pub fn detect_parse_uint_unit(i: &str) -> IResult<&str, u64> { let (i, unit) = alt(( value(1024, tag_no_case("kb")), value(1024 * 1024, tag_no_case("mb")), value(1024 * 1024 * 1024, tag_no_case("gb")), ))(i)?; return Ok((i, unit)); } pub fn detect_parse_uint_with_unit(i: &str) -> IResult<&str, T> { let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; let (i, unit) = opt(detect_parse_uint_unit)(i)?; if arg1 >= T::one() { if let Some(u) = unit { if T::max_value().to_u64().unwrap() / u < arg1.to_u64().unwrap() { return Err(Err::Error(make_error(i, ErrorKind::Verify))); } let ru64 = arg1 * T::from_u64(u).unwrap(); return Ok((i, ru64)); } } Ok((i, arg1)) } pub fn detect_parse_uint_start_equal( i: &str, ) -> IResult<&str, DetectUintData> { let (i, _) = opt(tag("="))(i)?; let (i, _) = opt(is_a(" "))(i)?; let (i, arg1) = detect_parse_uint_with_unit(i)?; Ok(( i, DetectUintData { arg1, arg2: T::min_value(), mode: DetectUintMode::DetectUintModeEqual, }, )) } pub fn detect_parse_uint_start_interval( i: &str, ) -> IResult<&str, DetectUintData> { let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; let (i, _) = opt(is_a(" "))(i)?; let (i, _) = alt((tag("-"), tag("<>")))(i)?; let (i, _) = opt(is_a(" "))(i)?; let (i, arg2) = verify(map_opt(digit1, |s: &str| s.parse::().ok()), |x| { x > &arg1 && *x - arg1 > T::one() })(i)?; Ok(( i, DetectUintData { arg1, arg2, mode: DetectUintMode::DetectUintModeRange, }, )) } fn detect_parse_uint_start_interval_inclusive( i: &str, ) -> IResult<&str, DetectUintData> { let (i, arg1) = verify(map_opt(digit1, |s: &str| s.parse::().ok()), |x| { *x > T::min_value() })(i)?; let (i, _) = opt(is_a(" "))(i)?; let (i, _) = alt((tag("-"), tag("<>")))(i)?; let (i, _) = opt(is_a(" "))(i)?; let (i, arg2) = verify(map_opt(digit1, |s: &str| s.parse::().ok()), |x| { *x > arg1 && *x < T::max_value() })(i)?; Ok(( i, DetectUintData { arg1: arg1 - T::one(), arg2: arg2 + T::one(), mode: DetectUintMode::DetectUintModeRange, }, )) } pub fn detect_parse_uint_mode(i: &str) -> IResult<&str, DetectUintMode> { let (i, mode) = alt(( value(DetectUintMode::DetectUintModeGte, tag(">=")), value(DetectUintMode::DetectUintModeLte, tag("<=")), value(DetectUintMode::DetectUintModeGt, tag(">")), value(DetectUintMode::DetectUintModeLt, tag("<")), value(DetectUintMode::DetectUintModeNe, tag("!=")), value(DetectUintMode::DetectUintModeNe, tag("!")), value(DetectUintMode::DetectUintModeEqual, tag("=")), ))(i)?; return Ok((i, mode)); } fn detect_parse_uint_start_symbol(i: &str) -> IResult<&str, DetectUintData> { let (i, mode) = detect_parse_uint_mode(i)?; let (i, _) = opt(is_a(" "))(i)?; let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; match mode { DetectUintMode::DetectUintModeNe => {} DetectUintMode::DetectUintModeLt => { if arg1 == T::min_value() { return Err(Err::Error(make_error(i, ErrorKind::Verify))); } } DetectUintMode::DetectUintModeLte => { if arg1 == T::max_value() { return Err(Err::Error(make_error(i, ErrorKind::Verify))); } } DetectUintMode::DetectUintModeGt => { if arg1 == T::max_value() { return Err(Err::Error(make_error(i, ErrorKind::Verify))); } } DetectUintMode::DetectUintModeGte => { if arg1 == T::min_value() { return Err(Err::Error(make_error(i, ErrorKind::Verify))); } } _ => { return Err(Err::Error(make_error(i, ErrorKind::MapOpt))); } } Ok(( i, DetectUintData { arg1, arg2: T::min_value(), mode, }, )) } pub fn detect_match_uint(x: &DetectUintData, val: T) -> bool { match x.mode { DetectUintMode::DetectUintModeEqual => { if val == x.arg1 { return true; } } DetectUintMode::DetectUintModeNe => { if val != x.arg1 { return true; } } DetectUintMode::DetectUintModeLt => { if val < x.arg1 { return true; } } DetectUintMode::DetectUintModeLte => { if val <= x.arg1 { return true; } } DetectUintMode::DetectUintModeGt => { if val > x.arg1 { return true; } } DetectUintMode::DetectUintModeGte => { if val >= x.arg1 { return true; } } DetectUintMode::DetectUintModeRange => { if val > x.arg1 && val < x.arg2 { return true; } } } return false; } pub fn detect_parse_uint_notending(i: &str) -> IResult<&str, DetectUintData> { let (i, _) = opt(is_a(" "))(i)?; let (i, uint) = alt(( detect_parse_uint_start_interval, detect_parse_uint_start_equal, detect_parse_uint_start_symbol, ))(i)?; Ok((i, uint)) } pub fn detect_parse_uint(i: &str) -> IResult<&str, DetectUintData> { let (i, uint) = detect_parse_uint_notending(i)?; let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?; Ok((i, uint)) } pub fn detect_parse_uint_inclusive(i: &str) -> IResult<&str, DetectUintData> { let (i, _) = opt(is_a(" "))(i)?; let (i, uint) = alt(( detect_parse_uint_start_interval_inclusive, detect_parse_uint_start_equal, detect_parse_uint_start_symbol, ))(i)?; let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?; Ok((i, uint)) } #[no_mangle] pub unsafe extern "C" fn rs_detect_u64_parse( ustr: *const std::os::raw::c_char, ) -> *mut DetectUintData { let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe if let Ok(s) = ft_name.to_str() { if let Ok((_, ctx)) = detect_parse_uint::(s) { let boxed = Box::new(ctx); return Box::into_raw(boxed) as *mut _; } } return std::ptr::null_mut(); } #[no_mangle] pub unsafe extern "C" fn rs_detect_u64_match( arg: u64, ctx: &DetectUintData, ) -> std::os::raw::c_int { if detect_match_uint(ctx, arg) { return 1; } return 0; } #[no_mangle] pub unsafe extern "C" fn rs_detect_u64_free(ctx: *mut std::os::raw::c_void) { // Just unbox... std::mem::drop(Box::from_raw(ctx as *mut DetectUintData)); } #[no_mangle] pub unsafe extern "C" fn rs_detect_u32_parse( ustr: *const std::os::raw::c_char, ) -> *mut DetectUintData { let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe if let Ok(s) = ft_name.to_str() { if let Ok((_, ctx)) = detect_parse_uint::(s) { let boxed = Box::new(ctx); return Box::into_raw(boxed) as *mut _; } } return std::ptr::null_mut(); } #[no_mangle] pub unsafe extern "C" fn rs_detect_u32_parse_inclusive( ustr: *const std::os::raw::c_char, ) -> *mut DetectUintData { let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe if let Ok(s) = ft_name.to_str() { if let Ok((_, ctx)) = detect_parse_uint_inclusive::(s) { let boxed = Box::new(ctx); return Box::into_raw(boxed) as *mut _; } } return std::ptr::null_mut(); } #[no_mangle] pub unsafe extern "C" fn rs_detect_u32_match( arg: u32, ctx: &DetectUintData, ) -> std::os::raw::c_int { if detect_match_uint(ctx, arg) { return 1; } return 0; } #[no_mangle] pub unsafe extern "C" fn rs_detect_u32_free(ctx: &mut DetectUintData) { // Just unbox... std::mem::drop(Box::from_raw(ctx)); } #[no_mangle] pub unsafe extern "C" fn rs_detect_u8_parse( ustr: *const std::os::raw::c_char, ) -> *mut DetectUintData { let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe if let Ok(s) = ft_name.to_str() { if let Ok((_, ctx)) = detect_parse_uint::(s) { let boxed = Box::new(ctx); return Box::into_raw(boxed) as *mut _; } } return std::ptr::null_mut(); } #[no_mangle] pub unsafe extern "C" fn rs_detect_u8_match( arg: u8, ctx: &DetectUintData, ) -> std::os::raw::c_int { if detect_match_uint(ctx, arg) { return 1; } return 0; } #[no_mangle] pub unsafe extern "C" fn rs_detect_u8_free(ctx: &mut DetectUintData) { // Just unbox... std::mem::drop(Box::from_raw(ctx)); } #[no_mangle] pub unsafe extern "C" fn rs_detect_u16_parse( ustr: *const std::os::raw::c_char, ) -> *mut DetectUintData { let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe if let Ok(s) = ft_name.to_str() { if let Ok((_, ctx)) = detect_parse_uint::(s) { let boxed = Box::new(ctx); return Box::into_raw(boxed) as *mut _; } } return std::ptr::null_mut(); } #[no_mangle] pub unsafe extern "C" fn rs_detect_u16_match( arg: u16, ctx: &DetectUintData, ) -> std::os::raw::c_int { if detect_match_uint(ctx, arg) { return 1; } return 0; } #[no_mangle] pub unsafe extern "C" fn rs_detect_u16_free(ctx: &mut DetectUintData) { // Just unbox... std::mem::drop(Box::from_raw(ctx)); } #[cfg(test)] mod tests { use super::*; #[test] fn test_parse_uint_unit() { match detect_parse_uint::(" 2kb") { Ok((_, val)) => { assert_eq!(val.arg1, 2048); } Err(_) => { assert!(false); } } match detect_parse_uint::("2kb") { Ok((_, _val)) => { assert!(false); } Err(_) => {} } match detect_parse_uint::("3MB") { Ok((_, val)) => { assert_eq!(val.arg1, 3 * 1024 * 1024); } Err(_) => { assert!(false); } } } }