diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 17:39:49 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 17:39:49 +0000 |
commit | a0aa2307322cd47bbf416810ac0292925e03be87 (patch) | |
tree | 37076262a026c4b48c8a0e84f44ff9187556ca35 /rust/src/detect | |
parent | Initial commit. (diff) | |
download | suricata-a0aa2307322cd47bbf416810ac0292925e03be87.tar.xz suricata-a0aa2307322cd47bbf416810ac0292925e03be87.zip |
Adding upstream version 1:7.0.3.upstream/1%7.0.3
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'rust/src/detect')
-rw-r--r-- | rust/src/detect/byte_math.rs | 1163 | ||||
-rw-r--r-- | rust/src/detect/error.rs | 37 | ||||
-rw-r--r-- | rust/src/detect/iprep.rs | 128 | ||||
-rw-r--r-- | rust/src/detect/mod.rs | 27 | ||||
-rw-r--r-- | rust/src/detect/parser.rs | 38 | ||||
-rw-r--r-- | rust/src/detect/requires.rs | 805 | ||||
-rw-r--r-- | rust/src/detect/stream_size.rs | 98 | ||||
-rw-r--r-- | rust/src/detect/uint.rs | 435 | ||||
-rw-r--r-- | rust/src/detect/uri.rs | 78 |
9 files changed, 2809 insertions, 0 deletions
diff --git a/rust/src/detect/byte_math.rs b/rust/src/detect/byte_math.rs new file mode 100644 index 0000000..80bd3d5 --- /dev/null +++ b/rust/src/detect/byte_math.rs @@ -0,0 +1,1163 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +// Author: Jeff Lucovsky <jlucovsky@oisf.net> + +use crate::detect::error::RuleParseError; +use crate::detect::parser::{parse_token, take_until_whitespace}; +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; + +use nom7::bytes::complete::tag; +use nom7::character::complete::multispace0; +use nom7::sequence::preceded; +use nom7::{Err, IResult}; +use std::str; + +pub const DETECT_BYTEMATH_FLAG_RELATIVE: u8 = 0x01; +pub const DETECT_BYTEMATH_FLAG_STRING: u8 = 0x02; +pub const DETECT_BYTEMATH_FLAG_BITMASK: u8 = 0x04; +pub const DETECT_BYTEMATH_FLAG_ENDIAN: u8 = 0x08; +pub const DETECT_BYTEMATH_FLAG_RVALUE_VAR: u8 = 0x10; +pub const DETECT_BYTEMATH_FLAG_NBYTES_VAR: u8 = 0x20; + +// Ensure required values are provided +const DETECT_BYTEMATH_FLAG_NBYTES: u8 = 0x1; +const DETECT_BYTEMATH_FLAG_OFFSET: u8 = 0x2; +const DETECT_BYTEMATH_FLAG_OPER: u8 = 0x4; +const DETECT_BYTEMATH_FLAG_RVALUE: u8 = 0x8; +const DETECT_BYTEMATH_FLAG_RESULT: u8 = 0x10; +const DETECT_BYTEMATH_FLAG_REQUIRED: u8 = DETECT_BYTEMATH_FLAG_RESULT + | DETECT_BYTEMATH_FLAG_RVALUE + | DETECT_BYTEMATH_FLAG_NBYTES + | DETECT_BYTEMATH_FLAG_OFFSET + | DETECT_BYTEMATH_FLAG_OPER; + +#[repr(u8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +// operators: +, -, /, *, <<, >> +pub enum ByteMathOperator { + OperatorNone = 1, + Addition = 2, + Subtraction = 3, + Division = 4, + Multiplication = 5, + LeftShift = 6, + RightShift = 7, +} + +#[repr(u8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +// endian <big|little|dce> +pub enum ByteMathEndian { + _EndianNone = 0, + BigEndian = 1, + LittleEndian = 2, + EndianDCE = 3, +} +pub const DETECT_BYTEMATH_ENDIAN_DEFAULT: ByteMathEndian = ByteMathEndian::BigEndian; + +#[repr(u8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ByteMathBase { + _BaseNone = 0, + BaseOct = 8, + BaseDec = 10, + BaseHex = 16, +} +const BASE_DEFAULT: ByteMathBase = ByteMathBase::BaseDec; + +// Fixed position parameter count: bytes, offset, oper, rvalue, result +// result is not parsed with the fixed position parameters as it's +// often swapped with optional parameters +pub const DETECT_BYTEMATH_FIXED_PARAM_COUNT: usize = 5; +// Optional parameters: endian, relative, string, dce, bitmask +pub const DETECT_BYTEMATH_MAX_PARAM_COUNT: usize = 10; + +#[derive(Debug)] +enum ResultValue { + Numeric(u64), + String(String), +} + +#[repr(C)] +#[derive(Debug)] +pub struct DetectByteMathData { + rvalue_str: *const c_char, + result: *const c_char, + nbytes_str: *const c_char, + rvalue: u32, + offset: i32, + bitmask_val: u32, + bitmask_shift_count: u16, + id: u16, + flags: u8, + local_id: u8, + nbytes: u8, + oper: ByteMathOperator, + endian: ByteMathEndian, // big, little, dce + base: ByteMathBase, // From string or dce +} + +impl Drop for DetectByteMathData { + fn drop(&mut self) { + unsafe { + if !self.result.is_null() { + let _ = CString::from_raw(self.result as *mut c_char); + } + if !self.rvalue_str.is_null() { + let _ = CString::from_raw(self.rvalue_str as *mut c_char); + } + if !self.nbytes_str.is_null() { + let _ = CString::from_raw(self.nbytes_str as *mut c_char); + } + } + } +} + +impl Default for DetectByteMathData { + fn default() -> Self { + DetectByteMathData { + local_id: 0, + flags: 0, + nbytes: 0, + offset: 0, + oper: ByteMathOperator::OperatorNone, + rvalue_str: std::ptr::null_mut(), + nbytes_str: std::ptr::null_mut(), + rvalue: 0, + result: std::ptr::null_mut(), + endian: DETECT_BYTEMATH_ENDIAN_DEFAULT, + base: BASE_DEFAULT, + bitmask_val: 0, + bitmask_shift_count: 0, + id: 0, + } + } +} + +impl DetectByteMathData { + pub fn new() -> Self { + Self { + ..Default::default() + } + } +} + +fn get_string_value(value: &str) -> Result<ByteMathBase, ()> { + let res = match value { + "hex" => ByteMathBase::BaseHex, + "oct" => ByteMathBase::BaseOct, + "dec" => ByteMathBase::BaseDec, + _ => return Err(()), + }; + + Ok(res) +} + +fn get_oper_value(value: &str) -> Result<ByteMathOperator, ()> { + let res = match value { + "+" => ByteMathOperator::Addition, + "-" => ByteMathOperator::Subtraction, + "/" => ByteMathOperator::Division, + "*" => ByteMathOperator::Multiplication, + "<<" => ByteMathOperator::LeftShift, + ">>" => ByteMathOperator::RightShift, + _ => return Err(()), + }; + + Ok(res) +} + +fn get_endian_value(value: &str) -> Result<ByteMathEndian, ()> { + let res = match value { + "big" => ByteMathEndian::BigEndian, + "little" => ByteMathEndian::LittleEndian, + "dce" => ByteMathEndian::EndianDCE, + _ => return Err(()), + }; + + Ok(res) +} + +// Parsed as a u64 for validation with u32 {min,max} so values greater than uint32 +// are not treated as a string value. +fn parse_var(input: &str) -> IResult<&str, ResultValue, RuleParseError<&str>> { + let (input, value) = parse_token(input)?; + if let Ok(val) = value.parse::<u64>() { + Ok((input, ResultValue::Numeric(val))) + } else { + Ok((input, ResultValue::String(value.to_string()))) + } +} + +fn parse_bytemath(input: &str) -> IResult<&str, DetectByteMathData, RuleParseError<&str>> { + // Inner utility function for easy error creation. + fn make_error(reason: String) -> nom7::Err<RuleParseError<&'static str>> { + Err::Error(RuleParseError::InvalidByteMath(reason)) + } + let (_, values) = nom7::multi::separated_list1( + tag(","), + preceded(multispace0, nom7::bytes::complete::is_not(",")), + )(input)?; + + if values.len() < DETECT_BYTEMATH_FIXED_PARAM_COUNT + || values.len() > DETECT_BYTEMATH_MAX_PARAM_COUNT + { + return Err(make_error(format!("Incorrect argument string; at least {} values must be specified but no more than {}: {:?}", + DETECT_BYTEMATH_FIXED_PARAM_COUNT, DETECT_BYTEMATH_MAX_PARAM_COUNT, input))); + } + + let mut required_flags: u8 = 0; + let mut byte_math = DetectByteMathData::new(); + //for value in &values[0..] { + for value in values { + let (mut val, mut name) = take_until_whitespace(value)?; + val = val.trim(); + name = name.trim(); + match name { + "oper" => { + if 0 != (required_flags & DETECT_BYTEMATH_FLAG_OPER) { + return Err(make_error("operator already set".to_string())); + } + byte_math.oper = match get_oper_value(val) { + Ok(val) => val, + Err(_) => { + return Err(make_error(format!("unknown oper value {}", val))); + } + }; + required_flags |= DETECT_BYTEMATH_FLAG_OPER; + } + "result" => { + if 0 != (required_flags & DETECT_BYTEMATH_FLAG_RESULT) { + return Err(make_error("result already set".to_string())); + } + let tmp: String = val + .parse() + .map_err(|_| make_error(format!("invalid result: {}", val)))?; + match CString::new(tmp) { + Ok(strval) => { + byte_math.result = strval.into_raw(); + required_flags |= DETECT_BYTEMATH_FLAG_RESULT; + } + _ => { + return Err(make_error( + "parse string not safely convertible to C".to_string(), + )); + } + } + } + "rvalue" => { + if 0 != (required_flags & DETECT_BYTEMATH_FLAG_RVALUE) { + return Err(make_error("rvalue already set".to_string())); + } + let (_, res) = parse_var(val)?; + match res { + ResultValue::Numeric(val) => { + if val >= u32::MIN.into() && val <= u32::MAX.into() { + byte_math.rvalue = val as u32 + } else { + return Err(make_error(format!( + "invalid rvalue value: must be between {} and {}: {}", + 1, + u32::MAX, + val + ))); + } + } + ResultValue::String(val) => match CString::new(val) { + Ok(newval) => { + byte_math.rvalue_str = newval.into_raw(); + byte_math.flags |= DETECT_BYTEMATH_FLAG_RVALUE_VAR; + } + _ => { + return Err(make_error( + "parse string not safely convertible to C".to_string(), + )) + } + }, + } + required_flags |= DETECT_BYTEMATH_FLAG_RVALUE; + } + "endian" => { + if 0 != (byte_math.flags & DETECT_BYTEMATH_FLAG_ENDIAN) { + return Err(make_error("endianess already set".to_string())); + } + byte_math.endian = match get_endian_value(val) { + Ok(val) => val, + Err(_) => { + return Err(make_error(format!("invalid endian value: {}", val))); + } + }; + byte_math.flags |= DETECT_BYTEMATH_FLAG_ENDIAN; + } + "dce" => { + if 0 != (byte_math.flags & DETECT_BYTEMATH_FLAG_ENDIAN) { + return Err(make_error("endianess already set".to_string())); + } + byte_math.flags |= DETECT_BYTEMATH_FLAG_ENDIAN; + byte_math.endian = ByteMathEndian::EndianDCE; + } + "string" => { + if 0 != (byte_math.flags & DETECT_BYTEMATH_FLAG_STRING) { + return Err(make_error("string already set".to_string())); + } + byte_math.base = match get_string_value(val) { + Ok(val) => val, + Err(_) => { + return Err(make_error(format!("invalid string value: {}", val))); + } + }; + byte_math.flags |= DETECT_BYTEMATH_FLAG_STRING; + } + "relative" => { + if 0 != (byte_math.flags & DETECT_BYTEMATH_FLAG_RELATIVE) { + return Err(make_error("relative already set".to_string())); + } + byte_math.flags |= DETECT_BYTEMATH_FLAG_RELATIVE; + } + "bitmask" => { + if 0 != (byte_math.flags & DETECT_BYTEMATH_FLAG_BITMASK) { + return Err(make_error("bitmask already set".to_string())); + } + let trimmed = if val.starts_with("0x") || val.starts_with("0X") { + &val[2..] + } else { + val + }; + + let val = u32::from_str_radix(trimmed, 16) + .map_err(|_| make_error(format!("invalid bitmask value: {}", value)))?; + byte_math.bitmask_val = val; + byte_math.flags |= DETECT_BYTEMATH_FLAG_BITMASK; + } + "offset" => { + if 0 != (required_flags & DETECT_BYTEMATH_FLAG_OFFSET) { + return Err(make_error("offset already set".to_string())); + } + byte_math.offset = val + .parse::<i32>() + .map_err(|_| make_error(format!("invalid offset value: {}", val)))?; + if byte_math.offset > 65535 || byte_math.offset < -65535 { + return Err(make_error(format!( + "invalid offset value: must be between -65535 and 65535: {}", + val + ))); + } + required_flags |= DETECT_BYTEMATH_FLAG_OFFSET; + } + "bytes" => { + if 0 != (required_flags & DETECT_BYTEMATH_FLAG_NBYTES) { + return Err(make_error("nbytes already set".to_string())); + } + let (_, res) = parse_var(val)?; + match res { + ResultValue::Numeric(val) => { + if (1..=10).contains(&val) { + byte_math.nbytes = val as u8 + } else { + return Err(make_error(format!( + "invalid nbytes value: must be between 1 and 10: {}", + val + ))); + } + } + ResultValue::String(val) => match CString::new(val) { + Ok(newval) => { + byte_math.nbytes_str = newval.into_raw(); + byte_math.flags |= DETECT_BYTEMATH_FLAG_NBYTES_VAR; + } + _ => { + return Err(make_error( + "parse string not safely convertible to C".to_string(), + )) + } + }, + } + required_flags |= DETECT_BYTEMATH_FLAG_NBYTES; + } + _ => { + return Err(make_error(format!("unknown byte_math keyword: {}", name))); + } + }; + } + + // Ensure required values are present + if (required_flags & DETECT_BYTEMATH_FLAG_REQUIRED) != DETECT_BYTEMATH_FLAG_REQUIRED { + return Err(make_error(format!( + "required byte_math parameters missing: \"{:?}\"", + input + ))); + } + + // Using left/right shift further restricts the value of nbytes. Note that + // validation has already ensured nbytes is in [1..10] + match byte_math.oper { + ByteMathOperator::LeftShift | ByteMathOperator::RightShift => { + if byte_math.nbytes > 4 { + return Err(make_error(format!("nbytes must be 1 through 4 (inclusive) when used with \"<<\" or \">>\"; {} is not valid", byte_math.nbytes))); + } + } + _ => {} + }; + + Ok((input, byte_math)) +} + +/// Intermediary function between the C code and the parsing functions. +#[no_mangle] +pub unsafe extern "C" fn ScByteMathParse(c_arg: *const c_char) -> *mut DetectByteMathData { + if c_arg.is_null() { + return std::ptr::null_mut(); + } + + let arg = match CStr::from_ptr(c_arg).to_str() { + Ok(arg) => arg, + Err(_) => { + return std::ptr::null_mut(); + } + }; + match parse_bytemath(arg) { + Ok((_, detect)) => return Box::into_raw(Box::new(detect)), + Err(_) => return std::ptr::null_mut(), + } +} + +#[no_mangle] +pub unsafe extern "C" fn ScByteMathFree(ptr: *mut DetectByteMathData) { + if !ptr.is_null() { + let _ = Box::from_raw(ptr); + } +} + +#[cfg(test)] +mod tests { + use super::*; + // structure equality only used by test cases + impl PartialEq for DetectByteMathData { + fn eq(&self, other: &Self) -> bool { + let mut res: bool = false; + + if !self.rvalue_str.is_null() && !other.rvalue_str.is_null() { + let s_val = unsafe { CStr::from_ptr(self.rvalue_str) }; + let o_val = unsafe { CStr::from_ptr(other.rvalue_str) }; + res = s_val == o_val; + } else if !self.rvalue_str.is_null() || !other.rvalue_str.is_null() { + return false; + } + + if !self.nbytes_str.is_null() && !other.nbytes_str.is_null() { + let s_val = unsafe { CStr::from_ptr(self.nbytes_str) }; + let o_val = unsafe { CStr::from_ptr(other.nbytes_str) }; + res = s_val == o_val; + } else if !self.nbytes_str.is_null() || !other.nbytes_str.is_null() { + return false; + } + + if !self.result.is_null() && !self.result.is_null() { + let s_val = unsafe { CStr::from_ptr(self.result) }; + let o_val = unsafe { CStr::from_ptr(other.result) }; + res = s_val == o_val; + } else if !self.result.is_null() || !self.result.is_null() { + return false; + } + + res && self.local_id == other.local_id + && self.flags == other.flags + && self.nbytes == other.nbytes + && self.offset == other.offset + && self.oper == other.oper + && self.rvalue == other.rvalue + && self.endian == other.endian + && self.base == other.base + && self.bitmask_val == other.bitmask_val + && self.bitmask_shift_count == other.bitmask_shift_count + && self.id == other.id + } + } + + fn valid_test( + args: &str, nbytes: u8, offset: i32, oper: ByteMathOperator, rvalue_str: &str, nbytes_str: &str, rvalue: u32, + result: &str, base: ByteMathBase, endian: ByteMathEndian, bitmask_val: u32, flags: u8, + ) { + let bmd = DetectByteMathData { + nbytes, + offset, + oper, + rvalue_str: if !rvalue_str.is_empty() { + CString::new(rvalue_str).unwrap().into_raw() + } else { + std::ptr::null_mut() + }, + nbytes_str: if !nbytes_str.is_empty() { + CString::new(nbytes_str).unwrap().into_raw() + } else { + std::ptr::null_mut() + }, + rvalue, + result: CString::new(result).unwrap().into_raw(), + base, + endian, + bitmask_val, + flags, + ..Default::default() + }; + + match parse_bytemath(args) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + } + + #[test] + fn test_parser_valid() { + valid_test( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result myresult, dce, string dec", + 4, + 3933, + ByteMathOperator::Addition, + "myrvalue", + "", + 0, + "myresult", + ByteMathBase::BaseDec, + ByteMathEndian::EndianDCE, + 0, + DETECT_BYTEMATH_FLAG_RVALUE_VAR + | DETECT_BYTEMATH_FLAG_STRING + | DETECT_BYTEMATH_FLAG_ENDIAN, + ); + + valid_test( + "bytes 4, offset 3933, oper +, rvalue 99, result other, dce, string dec", + 4, + 3933, + ByteMathOperator::Addition, + "", + "", + 99, + "other", + ByteMathBase::BaseDec, + ByteMathEndian::EndianDCE, + 0, + DETECT_BYTEMATH_FLAG_STRING | DETECT_BYTEMATH_FLAG_ENDIAN, + ); + + valid_test( + "bytes 4, offset -3933, oper +, rvalue myrvalue, result foo", + 4, + -3933, + ByteMathOperator::Addition, + "rvalue", + "", + 0, + "foo", + BASE_DEFAULT, + ByteMathEndian::BigEndian, + 0, + DETECT_BYTEMATH_FLAG_RVALUE_VAR, + ); + + valid_test( + "bytes nbytes_var, offset -3933, oper +, rvalue myrvalue, result foo", + 0, + -3933, + ByteMathOperator::Addition, + "rvalue", + "nbytes_var", + 0, + "foo", + BASE_DEFAULT, + ByteMathEndian::BigEndian, + 0, + DETECT_BYTEMATH_FLAG_RVALUE_VAR | DETECT_BYTEMATH_FLAG_NBYTES_VAR, + ); + + // Out of order + valid_test( + "string dec, endian big, result other, rvalue 99, oper +, offset 3933, bytes 4", + 4, + 3933, + ByteMathOperator::Addition, + "", + "", + 99, + "other", + ByteMathBase::BaseDec, + ByteMathEndian::BigEndian, + 0, + DETECT_BYTEMATH_FLAG_STRING | DETECT_BYTEMATH_FLAG_ENDIAN, + ); + } + + #[test] + fn test_parser_string_valid() { + let mut bmd = DetectByteMathData { + nbytes: 4, + offset: 3933, + oper: ByteMathOperator::Addition, + rvalue_str: CString::new("myrvalue").unwrap().into_raw(), + rvalue: 0, + result: CString::new("foo").unwrap().into_raw(), + endian: DETECT_BYTEMATH_ENDIAN_DEFAULT, + base: ByteMathBase::BaseDec, + flags: DETECT_BYTEMATH_FLAG_RVALUE_VAR | DETECT_BYTEMATH_FLAG_STRING, + ..Default::default() + }; + + match parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, string dec", + ) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.flags = DETECT_BYTEMATH_FLAG_RVALUE_VAR; + bmd.base = BASE_DEFAULT; + match parse_bytemath("bytes 4, offset 3933, oper +, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.flags = DETECT_BYTEMATH_FLAG_RVALUE_VAR | DETECT_BYTEMATH_FLAG_STRING; + bmd.base = ByteMathBase::BaseHex; + match parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, string hex", + ) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.base = ByteMathBase::BaseOct; + match parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, string oct", + ) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + } + + #[test] + fn test_parser_string_invalid() { + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, string decimal" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, string hexadecimal" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, string octal" + ) + .is_err() + ); + } + + #[test] + // bytes must be between 1 and 10; when combined with rshift/lshift, must be 4 or less + fn test_parser_bytes_invalid() { + assert!( + parse_bytemath("bytes 0, offset 3933, oper +, rvalue myrvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 11, offset 3933, oper +, rvalue myrvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 5, offset 3933, oper >>, rvalue myrvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 5, offset 3933, oper <<, rvalue myrvalue, result foo").is_err() + ); + } + + #[test] + fn test_parser_bitmask_invalid() { + assert!( + parse_bytemath("bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask 0x") + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask x12345678" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask X12345678" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask 0x123456789012" + ) + .is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask 0q") + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask maple" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask 0xGHIJKLMN" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask #*#*@-" + ) + .is_err() + ); + } + + #[test] + fn test_parser_bitmask_valid() { + let mut bmd = DetectByteMathData { + nbytes: 4, + offset: 3933, + oper: ByteMathOperator::Addition, + rvalue_str: CString::new("myrvalue").unwrap().into_raw(), + rvalue: 0, + result: CString::new("foo").unwrap().into_raw(), + endian: ByteMathEndian::BigEndian, + base: BASE_DEFAULT, + flags: DETECT_BYTEMATH_FLAG_RVALUE_VAR | DETECT_BYTEMATH_FLAG_BITMASK, + ..Default::default() + }; + + bmd.bitmask_val = 0x12345678; + match parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask 0x12345678", + ) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.bitmask_val = 0xffff1234; + match parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask ffff1234", + ) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.bitmask_val = 0xffff1234; + match parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, bitmask 0Xffff1234", + ) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + } + #[test] + fn test_parser_endian_valid() { + let mut bmd = DetectByteMathData { + nbytes: 4, + offset: 3933, + oper: ByteMathOperator::Addition, + rvalue_str: CString::new("myrvalue").unwrap().into_raw(), + rvalue: 0, + result: CString::new("foo").unwrap().into_raw(), + endian: ByteMathEndian::BigEndian, + base: BASE_DEFAULT, + flags: DETECT_BYTEMATH_FLAG_RVALUE_VAR | DETECT_BYTEMATH_FLAG_ENDIAN, + ..Default::default() + }; + + match parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, endian big", + ) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.endian = ByteMathEndian::LittleEndian; + match parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, endian little", + ) { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.endian = ByteMathEndian::EndianDCE; + match parse_bytemath("bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, dce") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.endian = DETECT_BYTEMATH_ENDIAN_DEFAULT; + bmd.flags = DETECT_BYTEMATH_FLAG_RVALUE_VAR; + match parse_bytemath("bytes 4, offset 3933, oper +, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + } + + #[test] + fn test_parser_endian_invalid() { + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, endian bigger" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, endian smaller" + ) + .is_err() + ); + + // endianess can only be specified once + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, endian big, dce" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, endian small, endian big" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 3933, oper +, rvalue myrvalue, result foo, endian small, dce" + ) + .is_err() + ); + } + + #[test] + fn test_parser_oper_valid() { + let mut bmd = DetectByteMathData { + nbytes: 4, + offset: 3933, + oper: ByteMathOperator::Addition, + rvalue_str: CString::new("myrvalue").unwrap().into_raw(), + rvalue: 0, + result: CString::new("foo").unwrap().into_raw(), + endian: ByteMathEndian::BigEndian, + base: BASE_DEFAULT, + flags: DETECT_BYTEMATH_FLAG_RVALUE_VAR, + ..Default::default() + }; + + match parse_bytemath("bytes 4, offset 3933, oper +, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.oper = ByteMathOperator::Subtraction; + match parse_bytemath("bytes 4, offset 3933, oper -, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.oper = ByteMathOperator::Multiplication; + match parse_bytemath("bytes 4, offset 3933, oper *, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + bmd.oper = ByteMathOperator::Division; + match parse_bytemath("bytes 4, offset 3933, oper /, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + bmd.oper = ByteMathOperator::RightShift; + match parse_bytemath("bytes 4, offset 3933, oper >>, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + bmd.oper = ByteMathOperator::LeftShift; + match parse_bytemath("bytes 4, offset 3933, oper <<, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + } + + #[test] + fn test_parser_oper_invalid() { + assert!( + parse_bytemath("bytes 4, offset 0, oper !, rvalue myvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 0, oper ^, rvalue myvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 0, oper <>, rvalue myvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 0, oper ><, rvalue myvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 0, oper <, rvalue myvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 0, oper >, rvalue myvalue, result foo").is_err() + ); + } + + #[test] + fn test_parser_rvalue_valid() { + let mut bmd = DetectByteMathData { + nbytes: 4, + offset: 47303, + oper: ByteMathOperator::Multiplication, + rvalue_str: std::ptr::null_mut(), + rvalue: 4294967295, + result: CString::new("foo").unwrap().into_raw(), + endian: DETECT_BYTEMATH_ENDIAN_DEFAULT, + base: BASE_DEFAULT, + ..Default::default() + }; + + match parse_bytemath("bytes 4, offset 47303, oper *, rvalue 4294967295 , result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.rvalue = 1; + match parse_bytemath("bytes 4, offset 47303, oper *, rvalue 1, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + bmd.rvalue = 0; + match parse_bytemath("bytes 4, offset 47303, oper *, rvalue 0, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + } + + #[test] + fn test_parser_rvalue_invalid() { + assert!( + parse_bytemath("bytes 4, offset 47303, oper *, rvalue 4294967296, result foo").is_err() + ); + } + + #[test] + fn test_parser_offset_valid() { + let mut bmd = DetectByteMathData { + nbytes: 4, + offset: -65535, + oper: ByteMathOperator::Multiplication, + rvalue_str: CString::new("myrvalue").unwrap().into_raw(), + rvalue: 0, + result: CString::new("foo").unwrap().into_raw(), + endian: DETECT_BYTEMATH_ENDIAN_DEFAULT, + base: BASE_DEFAULT, + flags: DETECT_BYTEMATH_FLAG_RVALUE_VAR, + ..Default::default() + }; + + match parse_bytemath("bytes 4, offset -65535, oper *, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + + bmd.offset = 65535; + match parse_bytemath("bytes 4, offset 65535, oper *, rvalue myrvalue, result foo") { + Ok((_, val)) => { + assert_eq!(val, bmd); + } + Err(_) => { + assert!(false); + } + } + } + + #[test] + // offset: numeric values must be between -65535 and 65535 + fn test_parser_offset_invalid() { + assert!( + parse_bytemath("bytes 4, offset -70000, oper *, rvalue myvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 70000, oper +, rvalue myvalue, result foo").is_err() + ); + } + + #[test] + fn test_parser_incomplete_args() { + assert!(parse_bytemath("").is_err()); + assert!(parse_bytemath("bytes 4").is_err()); + assert!(parse_bytemath("bytes 4, offset 0").is_err()); + assert!(parse_bytemath("bytes 4, offset 0, oper <<").is_err()); + } + + #[test] + fn test_parser_missing_required() { + assert!( + parse_bytemath("endian big, offset 3933, oper +, rvalue myrvalue, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 4, endian big, oper +, rvalue myrvalue, result foo,").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 3933, endian big, rvalue myrvalue, result foo") + .is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 3933, oper +, endian big, result foo").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 3933, oper +, rvalue myrvalue, endian big").is_err() + ); + } + + #[test] + fn test_parser_invalid_args() { + assert!(parse_bytemath("monkey banana").is_err()); + assert!(parse_bytemath("bytes nan").is_err()); + assert!(parse_bytemath("bytes 4, offset nan").is_err()); + assert!(parse_bytemath("bytes 4, offset 0, three 3, four 4, five 5, six 6, seven 7, eight 8, nine 9, ten 10, eleven 11").is_err()); + assert!( + parse_bytemath("bytes 4, offset 0, oper ><, rvalue myrvalue").is_err() + ); + assert!( + parse_bytemath("bytes 4, offset 0, oper +, rvalue myrvalue, endian endian").is_err() + ); + } + #[test] + fn test_parser_multiple() { + assert!( + parse_bytemath( + "bytes 4, bytes 4, offset 0, oper +, rvalue myrvalue, result myresult, endian big" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 0, offset 0, oper +, rvalue myrvalue, result myresult, endian big" + ) + .is_err() + ); + assert!( + parse_bytemath( + "bytes 4, offset 0, oper +, oper +, rvalue myrvalue, result myresult, endian big" + ) + .is_err() + ); + assert!(parse_bytemath("bytes 4, offset 0, oper +, rvalue myrvalue, rvalue myrvalue, result myresult, endian big").is_err()); + assert!(parse_bytemath("bytes 4, offset 0, oper +, rvalue myrvalue, result myresult, result myresult, endian big").is_err()); + assert!(parse_bytemath("bytes 4, offset 0, oper +, rvalue myrvalue, result myresult, endian big, endian big").is_err()); + } +} diff --git a/rust/src/detect/error.rs b/rust/src/detect/error.rs new file mode 100644 index 0000000..4959e2c --- /dev/null +++ b/rust/src/detect/error.rs @@ -0,0 +1,37 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +use nom7::error::{ErrorKind, ParseError}; + +/// Custom rule parse errors. +/// +/// Implemented based on the Nom example for implementing custom errors. +#[derive(Debug, PartialEq, Eq)] +pub enum RuleParseError<I> { + InvalidByteMath(String), + + Nom(I, ErrorKind), +} +impl<I> ParseError<I> for RuleParseError<I> { + fn from_error_kind(input: I, kind: ErrorKind) -> Self { + RuleParseError::Nom(input, kind) + } + + fn append(_: I, _: ErrorKind, other: Self) -> Self { + other + } +} diff --git a/rust/src/detect/iprep.rs b/rust/src/detect/iprep.rs new file mode 100644 index 0000000..16f5d9d --- /dev/null +++ b/rust/src/detect/iprep.rs @@ -0,0 +1,128 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +use super::uint::*; +use nom7::bytes::complete::{is_a, take_while}; +use nom7::character::complete::{alpha0, char, digit1}; +use nom7::combinator::{all_consuming, map_opt, map_res, opt}; +use nom7::error::{make_error, ErrorKind}; +use nom7::Err; +use nom7::IResult; + +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; +use std::str::FromStr; + +#[repr(u8)] +#[derive(Clone, Copy, PartialEq, Eq, FromPrimitive, Debug)] +pub enum DetectIPRepDataCmd { + IPRepCmdAny = 0, + IPRepCmdBoth = 1, + IPRepCmdSrc = 2, + IPRepCmdDst = 3, +} + +impl std::str::FromStr for DetectIPRepDataCmd { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "any" => Ok(DetectIPRepDataCmd::IPRepCmdAny), + "both" => Ok(DetectIPRepDataCmd::IPRepCmdBoth), + "src" => Ok(DetectIPRepDataCmd::IPRepCmdSrc), + "dst" => Ok(DetectIPRepDataCmd::IPRepCmdDst), + _ => Err(format!( + "'{}' is not a valid value for DetectIPRepDataCmd", + s + )), + } + } +} + +#[derive(Debug)] +#[repr(C)] +pub struct DetectIPRepData { + pub du8: DetectUintData<u8>, + pub cat: u8, + pub cmd: DetectIPRepDataCmd, +} + +pub fn is_alphanumeric_or_slash(chr: char) -> bool { + if chr.is_ascii_alphanumeric() { + return true; + } + if chr == '_' || chr == '-' { + return true; + } + return false; +} + +extern "C" { + pub fn SRepCatGetByShortname(name: *const c_char) -> u8; +} + +pub fn detect_parse_iprep(i: &str) -> IResult<&str, DetectIPRepData> { + let (i, _) = opt(is_a(" "))(i)?; + let (i, cmd) = map_res(alpha0, DetectIPRepDataCmd::from_str)(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = char(',')(i)?; + let (i, _) = opt(is_a(" "))(i)?; + + let (i, name) = take_while(is_alphanumeric_or_slash)(i)?; + // copy as to have final zero + let namez = CString::new(name).unwrap(); + let cat = unsafe { SRepCatGetByShortname(namez.as_ptr()) }; + if cat == 0 { + return Err(Err::Error(make_error(i, ErrorKind::MapOpt))); + } + + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = char(',')(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, mode) = detect_parse_uint_mode(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = char(',')(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::<u8>().ok())(i)?; + let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?; + let du8 = DetectUintData::<u8> { + arg1, + arg2: 0, + mode, + }; + return Ok((i, DetectIPRepData { du8, cat, cmd })); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_iprep_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectIPRepData { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_iprep(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_iprep_free(ctx: &mut DetectIPRepData) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} diff --git a/rust/src/detect/mod.rs b/rust/src/detect/mod.rs new file mode 100644 index 0000000..d33c9ae --- /dev/null +++ b/rust/src/detect/mod.rs @@ -0,0 +1,27 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +//! Module for rule parsing. + +pub mod byte_math; +pub mod error; +pub mod iprep; +pub mod parser; +pub mod stream_size; +pub mod uint; +pub mod uri; +pub mod requires; diff --git a/rust/src/detect/parser.rs b/rust/src/detect/parser.rs new file mode 100644 index 0000000..0ac5846 --- /dev/null +++ b/rust/src/detect/parser.rs @@ -0,0 +1,38 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +use crate::detect::error::RuleParseError; + +use nom7::bytes::complete::is_not; +use nom7::character::complete::multispace0; +use nom7::sequence::preceded; +use nom7::IResult; + +static WHITESPACE: &str = " \t\r\n"; +/// Parse all characters up until the next whitespace character. +pub fn take_until_whitespace(input: &str) -> IResult<&str, &str, RuleParseError<&str>> { + nom7::bytes::complete::is_not(WHITESPACE)(input) +} + +/// Parse the next token ignoring leading whitespace. +/// +/// A token is the next sequence of chars until a terminating character. Leading whitespace +/// is ignore. +pub fn parse_token(input: &str) -> IResult<&str, &str, RuleParseError<&str>> { + let terminators = "\n\r\t,;: "; + preceded(multispace0, is_not(terminators))(input) +} diff --git a/rust/src/detect/requires.rs b/rust/src/detect/requires.rs new file mode 100644 index 0000000..e9e1aca --- /dev/null +++ b/rust/src/detect/requires.rs @@ -0,0 +1,805 @@ +/* Copyright (C) 2023 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +use std::collections::{HashSet, VecDeque}; +use std::{cmp::Ordering, ffi::CStr}; + +// std::ffi::{c_char, c_int} is recommended these days, but requires +// Rust 1.64.0. +use std::os::raw::{c_char, c_int}; + +use nom7::bytes::complete::take_while; +use nom7::combinator::map; +use nom7::multi::{many1, separated_list1}; +use nom7::sequence::tuple; +use nom7::{ + branch::alt, + bytes::complete::{tag, take_till}, + character::complete::{char, multispace0}, + combinator::map_res, + sequence::preceded, + IResult, +}; + +#[derive(Debug, Eq, PartialEq)] +enum RequiresError { + /// Suricata is greater than the required version. + VersionGt, + + /// Suricata is less than the required version. + VersionLt(SuricataVersion), + + /// The running Suricata is missing a required feature. + MissingFeature(String), + + /// The Suricata version, of Suricata itself is bad and failed to parse. + BadSuricataVersion, + + /// The requires expression is bad and failed to parse. + BadRequires, + + /// MultipleVersions + MultipleVersions, + + /// Passed in requirements not a valid UTF-8 string. + Utf8Error, +} + +impl RequiresError { + /// Return a pointer to a C compatible constant error message. + const fn c_errmsg(&self) -> *const c_char { + let msg = match self { + Self::VersionGt => "Suricata version greater than required\0", + Self::VersionLt(_) => "Suricata version less than required\0", + Self::MissingFeature(_) => "Suricata missing a required feature\0", + Self::BadSuricataVersion => "Failed to parse running Suricata version\0", + Self::BadRequires => "Failed to parse requires expression\0", + Self::MultipleVersions => "Version may only be specified once\0", + Self::Utf8Error => "Requires expression is not valid UTF-8\0", + }; + msg.as_ptr() as *const c_char + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +enum VersionCompareOp { + Gt, + Gte, + Lt, + Lte, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +struct SuricataVersion { + major: u8, + minor: u8, + patch: u8, +} + +impl PartialOrd for SuricataVersion { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for SuricataVersion { + fn cmp(&self, other: &Self) -> Ordering { + match self.major.cmp(&other.major) { + Ordering::Equal => match self.minor.cmp(&other.minor) { + Ordering::Equal => self.patch.cmp(&other.patch), + other => other, + }, + other => other, + } + } +} + +impl std::fmt::Display for SuricataVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.major, self.minor, self.patch) + } +} + +impl SuricataVersion { + fn new(major: u8, minor: u8, patch: u8) -> Self { + Self { + major, + minor, + patch, + } + } +} + +/// Parse a version expression. +/// +/// Parse into a version expression into a nested array, for example: +/// +/// version: >= 7.0.3 < 8 | >= 8.0.3 +/// +/// would result in something like: +/// +/// [ +/// [{op: gte, version: 7.0.3}, {op:lt, version: 8}], +/// [{op: gte, version: 8.0.3}], +/// ] +fn parse_version_expression(input: &str) -> IResult<&str, Vec<Vec<RuleRequireVersion>>> { + let sep = preceded(multispace0, tag("|")); + let inner_parser = many1(tuple((parse_op, parse_version))); + let (input, versions) = separated_list1(sep, inner_parser)(input)?; + + let versions = versions + .into_iter() + .map(|versions| { + versions + .into_iter() + .map(|(op, version)| RuleRequireVersion { op, version }) + .collect() + }) + .collect(); + + Ok((input, versions)) +} + +#[derive(Debug, Eq, PartialEq)] +struct RuleRequireVersion { + pub op: VersionCompareOp, + pub version: SuricataVersion, +} + +#[derive(Debug, Default, Eq, PartialEq)] +struct Requires { + pub features: Vec<String>, + + /// The version expression. + /// + /// - All of the inner most must evaluate to true. + /// - To pass, any of the outer must be true. + pub version: Vec<Vec<RuleRequireVersion>>, +} + +fn parse_op(input: &str) -> IResult<&str, VersionCompareOp> { + preceded( + multispace0, + alt(( + map(tag(">="), |_| VersionCompareOp::Gte), + map(tag(">"), |_| VersionCompareOp::Gt), + map(tag("<="), |_| VersionCompareOp::Lte), + map(tag("<"), |_| VersionCompareOp::Lt), + )), + )(input) +} + +/// Parse the next part of the version. +/// +/// That is all chars up to eof, or the next '.' or '-'. +fn parse_next_version_part(input: &str) -> IResult<&str, u8> { + map_res( + take_till(|c| c == '.' || c == '-' || c == ' '), + |s: &str| s.parse::<u8>(), + )(input) +} + +/// Parse a version string into a SuricataVersion. +fn parse_version(input: &str) -> IResult<&str, SuricataVersion> { + let (input, major) = preceded(multispace0, parse_next_version_part)(input)?; + let (input, minor) = if input.is_empty() || input.starts_with(' ') { + (input, 0) + } else { + preceded(char('.'), parse_next_version_part)(input)? + }; + let (input, patch) = if input.is_empty() || input.starts_with(' ') { + (input, 0) + } else { + preceded(char('.'), parse_next_version_part)(input)? + }; + + Ok((input, SuricataVersion::new(major, minor, patch))) +} + +fn parse_key_value(input: &str) -> IResult<&str, (&str, &str)> { + // Parse the keyword, any sequence of characters, numbers or "-" or "_". + let (input, key) = preceded( + multispace0, + take_while(|c: char| c.is_alphanumeric() || c == '-' || c == '_'), + )(input)?; + let (input, value) = preceded(multispace0, take_till(|c: char| c == ','))(input)?; + Ok((input, (key, value))) +} + +fn parse_requires(mut input: &str) -> Result<Requires, RequiresError> { + let mut requires = Requires::default(); + + while !input.is_empty() { + let (rest, (keyword, value)) = + parse_key_value(input).map_err(|_| RequiresError::BadRequires)?; + match keyword { + "feature" => { + requires.features.push(value.trim().to_string()); + } + "version" => { + if !requires.version.is_empty() { + return Err(RequiresError::MultipleVersions); + } + let (_, versions) = + parse_version_expression(value).map_err(|_| RequiresError::BadRequires)?; + requires.version = versions; + } + _ => { + // Unknown keyword, allow by warn in case we extend + // this in the future. + SCLogWarning!("Unknown requires keyword: {}", keyword); + } + } + + // No consume any remaining ',' or whitespace. + input = rest.trim_start_matches(|c: char| c == ',' || c.is_whitespace()); + } + Ok(requires) +} + +fn parse_suricata_version(version: &CStr) -> Result<SuricataVersion, *const c_char> { + let version = version + .to_str() + .map_err(|_| RequiresError::BadSuricataVersion.c_errmsg())?; + let (_, version) = + parse_version(version).map_err(|_| RequiresError::BadSuricataVersion.c_errmsg())?; + Ok(version) +} + +fn check_version( + version: &RuleRequireVersion, suricata_version: &SuricataVersion, +) -> Result<(), RequiresError> { + match version.op { + VersionCompareOp::Gt => { + if suricata_version <= &version.version { + return Err(RequiresError::VersionLt(version.version.clone())); + } + } + VersionCompareOp::Gte => { + if suricata_version < &version.version { + return Err(RequiresError::VersionLt(version.version.clone())); + } + } + VersionCompareOp::Lt => { + if suricata_version >= &version.version { + return Err(RequiresError::VersionGt); + } + } + VersionCompareOp::Lte => { + if suricata_version > &version.version { + return Err(RequiresError::VersionGt); + } + } + } + Ok(()) +} + +fn check_requires( + requires: &Requires, suricata_version: &SuricataVersion, +) -> Result<(), RequiresError> { + if !requires.version.is_empty() { + let mut errs = VecDeque::new(); + let mut ok = 0; + for or_versions in &requires.version { + let mut err = None; + for version in or_versions { + if let Err(_err) = check_version(version, suricata_version) { + err = Some(_err); + break; + } + } + if let Some(err) = err { + errs.push_back(err); + } else { + ok += 1; + } + } + if ok == 0 { + return Err(errs.pop_front().unwrap()); + } + } + + for feature in &requires.features { + if !crate::feature::requires(feature) { + return Err(RequiresError::MissingFeature(feature.to_string())); + } + } + + Ok(()) +} + +/// Status object to hold required features and the latest version of +/// Suricata required. +/// +/// Full qualified name as it is exposed to C. +#[derive(Debug, Default)] +pub struct SCDetectRequiresStatus { + min_version: Option<SuricataVersion>, + features: HashSet<String>, + + /// Number of rules that didn't meet a feature. + feature_count: u64, + + /// Number of rules where the Suricata version wasn't new enough. + lt_count: u64, + + /// Number of rules where the Suricata version was too new. + gt_count: u64, +} + +#[no_mangle] +pub extern "C" fn SCDetectRequiresStatusNew() -> *mut SCDetectRequiresStatus { + Box::into_raw(Box::default()) +} + +#[no_mangle] +pub unsafe extern "C" fn SCDetectRequiresStatusFree(status: *mut SCDetectRequiresStatus) { + if !status.is_null() { + std::mem::drop(Box::from_raw(status)); + } +} + +#[no_mangle] +pub unsafe extern "C" fn SCDetectRequiresStatusLog( + status: &mut SCDetectRequiresStatus, suricata_version: *const c_char, tenant_id: u32, +) { + let suricata_version = CStr::from_ptr(suricata_version) + .to_str() + .unwrap_or("<unknown>"); + + let mut parts = vec![]; + if status.lt_count > 0 { + let min_version = status + .min_version + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_else(|| "<unknown>".to_string()); + let msg = format!( + "{} {} skipped because the running Suricata version {} is less than {}", + status.lt_count, + if status.lt_count > 1 { + "rules were" + } else { + "rule was" + }, + suricata_version, + &min_version + ); + parts.push(msg); + } + if status.gt_count > 0 { + let msg = format!( + "{} {} for an older version Suricata", + status.gt_count, + if status.gt_count > 1 { + "rules were skipped as they are" + } else { + "rule was skipped as it is" + } + ); + parts.push(msg); + } + if status.feature_count > 0 { + let features = status + .features + .iter() + .map(|f| f.to_string()) + .collect::<Vec<String>>() + .join(", "); + let msg = format!( + "{}{} {} skipped because the running Suricata version does not have feature{}: [{}]", + if tenant_id > 0 { + format!("tenant id: {} ", tenant_id) + } else { + String::new() + }, + status.feature_count, + if status.feature_count > 1 { + "rules were" + } else { + "rule was" + }, + if status.feature_count > 1 { "s" } else { "" }, + &features + ); + parts.push(msg); + } + + let msg = parts.join("; "); + + if status.lt_count > 0 { + SCLogNotice!("{}", &msg); + } else if status.gt_count > 0 || status.feature_count > 0 { + SCLogInfo!("{}", &msg); + } +} + +/// Parse a "requires" rule option. +/// +/// Return values: +/// * 0 - OK, rule should continue loading +/// * -1 - Error parsing the requires content +/// * -4 - Requirements not met, don't continue loading the rule, this +/// value is chosen so it can be passed back to the options parser +/// as its treated as a non-fatal silent error. +#[no_mangle] +pub unsafe extern "C" fn SCDetectCheckRequires( + requires: *const c_char, suricata_version_string: *const c_char, errstr: *mut *const c_char, + status: &mut SCDetectRequiresStatus, +) -> c_int { + // First parse the running Suricata version. + let suricata_version = match parse_suricata_version(CStr::from_ptr(suricata_version_string)) { + Ok(version) => version, + Err(err) => { + *errstr = err; + return -1; + } + }; + + let requires = match CStr::from_ptr(requires) + .to_str() + .map_err(|_| RequiresError::Utf8Error) + .and_then(parse_requires) + { + Ok(requires) => requires, + Err(err) => { + *errstr = err.c_errmsg(); + return -1; + } + }; + + match check_requires(&requires, &suricata_version) { + Ok(()) => 0, + Err(err) => { + match &err { + RequiresError::VersionLt(version) => { + if let Some(min_version) = &status.min_version { + if version > min_version { + status.min_version = Some(version.clone()); + } + } else { + status.min_version = Some(version.clone()); + } + status.lt_count += 1; + } + RequiresError::MissingFeature(feature) => { + status.features.insert(feature.to_string()); + status.feature_count += 1; + } + RequiresError::VersionGt => { + status.gt_count += 1; + } + _ => {} + } + *errstr = err.c_errmsg(); + return -4; + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_suricata_version() { + // 7.1.1 < 7.1.2 + assert!(SuricataVersion::new(7, 1, 1) < SuricataVersion::new(7, 1, 2)); + + // 7.1.1 <= 7.1.2 + assert!(SuricataVersion::new(7, 1, 1) <= SuricataVersion::new(7, 1, 2)); + + // 7.1.1 <= 7.1.1 + assert!(SuricataVersion::new(7, 1, 1) <= SuricataVersion::new(7, 1, 1)); + + // NOT 7.1.1 < 7.1.1 + assert!(SuricataVersion::new(7, 1, 1) >= SuricataVersion::new(7, 1, 1)); + + // 7.3.1 < 7.22.1 + assert!(SuricataVersion::new(7, 3, 1) < SuricataVersion::new(7, 22, 1)); + + // 7.22.1 >= 7.3.4 + assert!(SuricataVersion::new(7, 22, 1) >= SuricataVersion::new(7, 3, 4)); + } + + #[test] + fn test_parse_op() { + assert_eq!(parse_op(">").unwrap().1, VersionCompareOp::Gt); + assert_eq!(parse_op(">=").unwrap().1, VersionCompareOp::Gte); + assert_eq!(parse_op("<").unwrap().1, VersionCompareOp::Lt); + assert_eq!(parse_op("<=").unwrap().1, VersionCompareOp::Lte); + + assert!(parse_op("=").is_err()); + } + + #[test] + fn test_parse_version() { + assert_eq!( + parse_version("7").unwrap().1, + SuricataVersion { + major: 7, + minor: 0, + patch: 0, + } + ); + + assert_eq!( + parse_version("7.1").unwrap().1, + SuricataVersion { + major: 7, + minor: 1, + patch: 0, + } + ); + + assert_eq!( + parse_version("7.1.2").unwrap().1, + SuricataVersion { + major: 7, + minor: 1, + patch: 2, + } + ); + + // Suricata pre-releases will have a suffix starting with a + // '-', so make sure we accept those versions as well. + assert_eq!( + parse_version("8.0.0-dev").unwrap().1, + SuricataVersion { + major: 8, + minor: 0, + patch: 0, + } + ); + + assert!(parse_version("7.1.2a").is_err()); + assert!(parse_version("a").is_err()); + assert!(parse_version("777").is_err()); + assert!(parse_version("product-1").is_err()); + } + + #[test] + fn test_parse_requires() { + let requires = parse_requires(" feature geoip").unwrap(); + assert_eq!(&requires.features[0], "geoip"); + + let requires = parse_requires(" feature geoip, feature lua ").unwrap(); + assert_eq!(&requires.features[0], "geoip"); + assert_eq!(&requires.features[1], "lua"); + + let requires = parse_requires("version >=7").unwrap(); + assert_eq!( + requires, + Requires { + features: vec![], + version: vec![vec![RuleRequireVersion { + op: VersionCompareOp::Gte, + version: SuricataVersion { + major: 7, + minor: 0, + patch: 0, + } + }]], + } + ); + + let requires = parse_requires("version >= 7.1").unwrap(); + assert_eq!( + requires, + Requires { + features: vec![], + version: vec![vec![RuleRequireVersion { + op: VersionCompareOp::Gte, + version: SuricataVersion { + major: 7, + minor: 1, + patch: 0, + } + }]], + } + ); + + let requires = parse_requires("feature output::file-store, version >= 7.1.2").unwrap(); + assert_eq!( + requires, + Requires { + features: vec!["output::file-store".to_string()], + version: vec![vec![RuleRequireVersion { + op: VersionCompareOp::Gte, + version: SuricataVersion { + major: 7, + minor: 1, + patch: 2, + } + }]], + } + ); + + let requires = parse_requires("feature geoip, version >= 7.1.2 < 8").unwrap(); + assert_eq!( + requires, + Requires { + features: vec!["geoip".to_string()], + version: vec![vec![ + RuleRequireVersion { + op: VersionCompareOp::Gte, + version: SuricataVersion { + major: 7, + minor: 1, + patch: 2, + }, + }, + RuleRequireVersion { + op: VersionCompareOp::Lt, + version: SuricataVersion { + major: 8, + minor: 0, + patch: 0, + } + } + ]], + } + ); + } + + #[test] + fn test_check_requires() { + // Have 7.0.4, require >= 8. + let suricata_version = SuricataVersion::new(7, 0, 4); + let requires = parse_requires("version >= 8").unwrap(); + assert_eq!( + check_requires(&requires, &suricata_version), + Err(RequiresError::VersionLt(SuricataVersion { + major: 8, + minor: 0, + patch: 0, + })), + ); + + // Have 7.0.4, require 7.0.3. + let suricata_version = SuricataVersion::new(7, 0, 4); + let requires = parse_requires("version >= 7.0.3").unwrap(); + assert_eq!(check_requires(&requires, &suricata_version), Ok(())); + + // Have 8.0.0, require >= 7.0.0 and < 8.0 + let suricata_version = SuricataVersion::new(8, 0, 0); + let requires = parse_requires("version >= 7.0.0 < 8").unwrap(); + assert_eq!( + check_requires(&requires, &suricata_version), + Err(RequiresError::VersionGt) + ); + + // Have 8.0.0, require >= 7.0.0 and < 9.0 + let suricata_version = SuricataVersion::new(8, 0, 0); + let requires = parse_requires("version >= 7.0.0 < 9").unwrap(); + assert_eq!(check_requires(&requires, &suricata_version), Ok(())); + + // Require feature foobar. + let suricata_version = SuricataVersion::new(8, 0, 0); + let requires = parse_requires("feature foobar").unwrap(); + assert_eq!( + check_requires(&requires, &suricata_version), + Err(RequiresError::MissingFeature("foobar".to_string())) + ); + + // Require feature foobar, but this time we have the feature. + let suricata_version = SuricataVersion::new(8, 0, 0); + let requires = parse_requires("feature true_foobar").unwrap(); + assert_eq!(check_requires(&requires, &suricata_version), Ok(())); + + let suricata_version = SuricataVersion::new(8, 0, 1); + let requires = parse_requires("version >= 7.0.3 < 8").unwrap(); + assert!(check_requires(&requires, &suricata_version).is_err()); + + let suricata_version = SuricataVersion::new(7, 0, 1); + let requires = parse_requires("version >= 7.0.3 < 8").unwrap(); + assert!(check_requires(&requires, &suricata_version).is_err()); + + let suricata_version = SuricataVersion::new(7, 0, 3); + let requires = parse_requires("version >= 7.0.3 < 8").unwrap(); + assert!(check_requires(&requires, &suricata_version).is_ok()); + + let suricata_version = SuricataVersion::new(8, 0, 3); + let requires = parse_requires("version >= 7.0.3 < 8 | >= 8.0.3").unwrap(); + assert!(check_requires(&requires, &suricata_version).is_ok()); + + let suricata_version = SuricataVersion::new(8, 0, 2); + let requires = parse_requires("version >= 7.0.3 < 8 | >= 8.0.3").unwrap(); + assert!(check_requires(&requires, &suricata_version).is_err()); + + let suricata_version = SuricataVersion::new(7, 0, 2); + let requires = parse_requires("version >= 7.0.3 < 8 | >= 8.0.3").unwrap(); + assert!(check_requires(&requires, &suricata_version).is_err()); + + let suricata_version = SuricataVersion::new(7, 0, 3); + let requires = parse_requires("version >= 7.0.3 < 8 | >= 8.0.3").unwrap(); + assert!(check_requires(&requires, &suricata_version).is_ok()); + + // Example of something that requires a fix/feature that was + // implemented in 7.0.5, 8.0.4, 9.0.3. + let requires = parse_requires("version >= 7.0.5 < 8 | >= 8.0.4 < 9 | >= 9.0.3").unwrap(); + assert!(check_requires(&requires, &SuricataVersion::new(6, 0, 0)).is_err()); + assert!(check_requires(&requires, &SuricataVersion::new(7, 0, 4)).is_err()); + assert!(check_requires(&requires, &SuricataVersion::new(7, 0, 5)).is_ok()); + assert!(check_requires(&requires, &SuricataVersion::new(8, 0, 3)).is_err()); + assert!(check_requires(&requires, &SuricataVersion::new(8, 0, 4)).is_ok()); + assert!(check_requires(&requires, &SuricataVersion::new(9, 0, 2)).is_err()); + assert!(check_requires(&requires, &SuricataVersion::new(9, 0, 3)).is_ok()); + assert!(check_requires(&requires, &SuricataVersion::new(10, 0, 0)).is_ok()); + + let requires = parse_requires("version >= 8 < 9").unwrap(); + assert!(check_requires(&requires, &SuricataVersion::new(6, 0, 0)).is_err()); + assert!(check_requires(&requires, &SuricataVersion::new(7, 0, 0)).is_err()); + assert!(check_requires(&requires, &SuricataVersion::new(8, 0, 0)).is_ok()); + assert!(check_requires(&requires, &SuricataVersion::new(9, 0, 0)).is_err()); + + // Unknown keyword. + let requires = parse_requires("feature lua, foo bar, version >= 7.0.3").unwrap(); + assert_eq!( + requires, + Requires { + features: vec!["lua".to_string()], + version: vec![vec![RuleRequireVersion { + op: VersionCompareOp::Gte, + version: SuricataVersion { + major: 7, + minor: 0, + patch: 3, + } + }]], + } + ); + } + + #[test] + fn test_parse_version_expression() { + let version_str = ">= 7.0.3 < 8 | >= 8.0.3"; + let (rest, versions) = parse_version_expression(version_str).unwrap(); + assert!(rest.is_empty()); + assert_eq!( + versions, + vec![ + vec![ + RuleRequireVersion { + op: VersionCompareOp::Gte, + version: SuricataVersion { + major: 7, + minor: 0, + patch: 3, + } + }, + RuleRequireVersion { + op: VersionCompareOp::Lt, + version: SuricataVersion { + major: 8, + minor: 0, + patch: 0, + } + }, + ], + vec![RuleRequireVersion { + op: VersionCompareOp::Gte, + version: SuricataVersion { + major: 8, + minor: 0, + patch: 3, + } + },], + ] + ); + } +} diff --git a/rust/src/detect/stream_size.rs b/rust/src/detect/stream_size.rs new file mode 100644 index 0000000..cb8c826 --- /dev/null +++ b/rust/src/detect/stream_size.rs @@ -0,0 +1,98 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +use super::uint::*; +use nom7::bytes::complete::{is_a, take_while}; +use nom7::character::complete::{alpha0, char, digit1}; +use nom7::combinator::{all_consuming, map_opt, map_res, opt}; +use nom7::IResult; + +use std::ffi::CStr; +use std::str::FromStr; + +#[repr(u8)] +#[derive(Clone, Copy, PartialEq, Eq, FromPrimitive, Debug)] +pub enum DetectStreamSizeDataFlags { + StreamSizeServer = 1, + StreamSizeClient = 2, + StreamSizeBoth = 3, + StreamSizeEither = 4, +} + +impl std::str::FromStr for DetectStreamSizeDataFlags { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "server" => Ok(DetectStreamSizeDataFlags::StreamSizeServer), + "client" => Ok(DetectStreamSizeDataFlags::StreamSizeClient), + "both" => Ok(DetectStreamSizeDataFlags::StreamSizeBoth), + "either" => Ok(DetectStreamSizeDataFlags::StreamSizeEither), + _ => Err(format!( + "'{}' is not a valid value for DetectStreamSizeDataFlags", + s + )), + } + } +} + +#[derive(Debug)] +#[repr(C)] +pub struct DetectStreamSizeData { + pub flags: DetectStreamSizeDataFlags, + pub du32: DetectUintData<u32>, +} + +pub fn detect_parse_stream_size(i: &str) -> IResult<&str, DetectStreamSizeData> { + let (i, _) = opt(is_a(" "))(i)?; + let (i, flags) = map_res(alpha0, DetectStreamSizeDataFlags::from_str)(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = char(',')(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, mode) = detect_parse_uint_mode(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = char(',')(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::<u32>().ok())(i)?; + let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?; + let du32 = DetectUintData::<u32> { + arg1, + arg2: 0, + mode, + }; + Ok((i, DetectStreamSizeData { flags, du32 })) +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_stream_size_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectStreamSizeData { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_stream_size(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_stream_size_free(ctx: &mut DetectStreamSizeData) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} diff --git a/rust/src/detect/uint.rs b/rust/src/detect/uint.rs new file mode 100644 index 0000000..3d6a5ba --- /dev/null +++ b/rust/src/detect/uint.rs @@ -0,0 +1,435 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +use nom7::branch::alt; +use nom7::bytes::complete::{is_a, tag, tag_no_case, take_while}; +use nom7::character::complete::digit1; +use nom7::combinator::{all_consuming, map_opt, opt, value, verify}; +use nom7::error::{make_error, ErrorKind}; +use nom7::Err; +use nom7::IResult; + +use std::ffi::CStr; + +#[derive(PartialEq, Eq, Clone, Debug)] +#[repr(u8)] +pub enum DetectUintMode { + DetectUintModeEqual, + DetectUintModeLt, + DetectUintModeLte, + DetectUintModeGt, + DetectUintModeGte, + DetectUintModeRange, + DetectUintModeNe, +} + +#[derive(Debug)] +#[repr(C)] +pub struct DetectUintData<T> { + pub arg1: T, + pub arg2: T, + pub mode: DetectUintMode, +} + +pub trait DetectIntType: + std::str::FromStr + + std::cmp::PartialOrd + + num::PrimInt + + num::Bounded + + num::ToPrimitive + + num::FromPrimitive +{ +} +impl<T> DetectIntType for T where + T: std::str::FromStr + + std::cmp::PartialOrd + + num::PrimInt + + num::Bounded + + num::ToPrimitive + + num::FromPrimitive +{ +} + +pub fn detect_parse_uint_unit(i: &str) -> IResult<&str, u64> { + let (i, unit) = alt(( + value(1024, tag_no_case("kb")), + value(1024 * 1024, tag_no_case("mb")), + value(1024 * 1024 * 1024, tag_no_case("gb")), + ))(i)?; + return Ok((i, unit)); +} + +pub fn detect_parse_uint_with_unit<T: DetectIntType>(i: &str) -> IResult<&str, T> { + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::<T>().ok())(i)?; + let (i, unit) = opt(detect_parse_uint_unit)(i)?; + if arg1 >= T::one() { + if let Some(u) = unit { + if T::max_value().to_u64().unwrap() / u < arg1.to_u64().unwrap() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + let ru64 = arg1 * T::from_u64(u).unwrap(); + return Ok((i, ru64)); + } + } + Ok((i, arg1)) +} + +pub fn detect_parse_uint_start_equal<T: DetectIntType>( + i: &str, +) -> IResult<&str, DetectUintData<T>> { + let (i, _) = opt(tag("="))(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg1) = detect_parse_uint_with_unit(i)?; + Ok(( + i, + DetectUintData { + arg1, + arg2: T::min_value(), + mode: DetectUintMode::DetectUintModeEqual, + }, + )) +} + +pub fn detect_parse_uint_start_interval<T: DetectIntType>( + i: &str, +) -> IResult<&str, DetectUintData<T>> { + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::<T>().ok())(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = alt((tag("-"), tag("<>")))(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg2) = verify(map_opt(digit1, |s: &str| s.parse::<T>().ok()), |x| { + x > &arg1 && *x - arg1 > T::one() + })(i)?; + Ok(( + i, + DetectUintData { + arg1, + arg2, + mode: DetectUintMode::DetectUintModeRange, + }, + )) +} + +fn detect_parse_uint_start_interval_inclusive<T: DetectIntType>( + i: &str, +) -> IResult<&str, DetectUintData<T>> { + let (i, arg1) = verify(map_opt(digit1, |s: &str| s.parse::<T>().ok()), |x| { + *x > T::min_value() + })(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = alt((tag("-"), tag("<>")))(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg2) = verify(map_opt(digit1, |s: &str| s.parse::<T>().ok()), |x| { + *x > arg1 && *x < T::max_value() + })(i)?; + Ok(( + i, + DetectUintData { + arg1: arg1 - T::one(), + arg2: arg2 + T::one(), + mode: DetectUintMode::DetectUintModeRange, + }, + )) +} + +pub fn detect_parse_uint_mode(i: &str) -> IResult<&str, DetectUintMode> { + let (i, mode) = alt(( + value(DetectUintMode::DetectUintModeGte, tag(">=")), + value(DetectUintMode::DetectUintModeLte, tag("<=")), + value(DetectUintMode::DetectUintModeGt, tag(">")), + value(DetectUintMode::DetectUintModeLt, tag("<")), + value(DetectUintMode::DetectUintModeNe, tag("!=")), + value(DetectUintMode::DetectUintModeNe, tag("!")), + value(DetectUintMode::DetectUintModeEqual, tag("=")), + ))(i)?; + return Ok((i, mode)); +} + +fn detect_parse_uint_start_symbol<T: DetectIntType>(i: &str) -> IResult<&str, DetectUintData<T>> { + let (i, mode) = detect_parse_uint_mode(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::<T>().ok())(i)?; + + match mode { + DetectUintMode::DetectUintModeNe => {} + DetectUintMode::DetectUintModeLt => { + if arg1 == T::min_value() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + } + DetectUintMode::DetectUintModeLte => { + if arg1 == T::max_value() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + } + DetectUintMode::DetectUintModeGt => { + if arg1 == T::max_value() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + } + DetectUintMode::DetectUintModeGte => { + if arg1 == T::min_value() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + } + _ => { + return Err(Err::Error(make_error(i, ErrorKind::MapOpt))); + } + } + + Ok(( + i, + DetectUintData { + arg1, + arg2: T::min_value(), + mode, + }, + )) +} + +pub fn detect_match_uint<T: DetectIntType>(x: &DetectUintData<T>, val: T) -> bool { + match x.mode { + DetectUintMode::DetectUintModeEqual => { + if val == x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeNe => { + if val != x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeLt => { + if val < x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeLte => { + if val <= x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeGt => { + if val > x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeGte => { + if val >= x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeRange => { + if val > x.arg1 && val < x.arg2 { + return true; + } + } + } + return false; +} + +pub fn detect_parse_uint_notending<T: DetectIntType>(i: &str) -> IResult<&str, DetectUintData<T>> { + let (i, _) = opt(is_a(" "))(i)?; + let (i, uint) = alt(( + detect_parse_uint_start_interval, + detect_parse_uint_start_equal, + detect_parse_uint_start_symbol, + ))(i)?; + Ok((i, uint)) +} + +pub fn detect_parse_uint<T: DetectIntType>(i: &str) -> IResult<&str, DetectUintData<T>> { + let (i, uint) = detect_parse_uint_notending(i)?; + let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?; + Ok((i, uint)) +} + +pub fn detect_parse_uint_inclusive<T: DetectIntType>(i: &str) -> IResult<&str, DetectUintData<T>> { + let (i, _) = opt(is_a(" "))(i)?; + let (i, uint) = alt(( + detect_parse_uint_start_interval_inclusive, + detect_parse_uint_start_equal, + detect_parse_uint_start_symbol, + ))(i)?; + let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?; + Ok((i, uint)) +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u64_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData<u64> { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint::<u64>(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u64_match( + arg: u64, ctx: &DetectUintData<u64>, +) -> std::os::raw::c_int { + if detect_match_uint(ctx, arg) { + return 1; + } + return 0; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u64_free(ctx: *mut std::os::raw::c_void) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx as *mut DetectUintData<u64>)); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u32_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData<u32> { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint::<u32>(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u32_parse_inclusive( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData<u32> { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint_inclusive::<u32>(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u32_match( + arg: u32, ctx: &DetectUintData<u32>, +) -> std::os::raw::c_int { + if detect_match_uint(ctx, arg) { + return 1; + } + return 0; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u32_free(ctx: &mut DetectUintData<u32>) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u8_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData<u8> { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint::<u8>(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u8_match( + arg: u8, ctx: &DetectUintData<u8>, +) -> std::os::raw::c_int { + if detect_match_uint(ctx, arg) { + return 1; + } + return 0; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u8_free(ctx: &mut DetectUintData<u8>) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u16_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData<u16> { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint::<u16>(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u16_match( + arg: u16, ctx: &DetectUintData<u16>, +) -> std::os::raw::c_int { + if detect_match_uint(ctx, arg) { + return 1; + } + return 0; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u16_free(ctx: &mut DetectUintData<u16>) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_uint_unit() { + match detect_parse_uint::<u64>(" 2kb") { + Ok((_, val)) => { + assert_eq!(val.arg1, 2048); + } + Err(_) => { + assert!(false); + } + } + match detect_parse_uint::<u8>("2kb") { + Ok((_, _val)) => { + assert!(false); + } + Err(_) => {} + } + match detect_parse_uint::<u32>("3MB") { + Ok((_, val)) => { + assert_eq!(val.arg1, 3 * 1024 * 1024); + } + Err(_) => { + assert!(false); + } + } + } +} diff --git a/rust/src/detect/uri.rs b/rust/src/detect/uri.rs new file mode 100644 index 0000000..ae98278 --- /dev/null +++ b/rust/src/detect/uri.rs @@ -0,0 +1,78 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +use super::uint::*; +use nom7::branch::alt; +use nom7::bytes::complete::{is_a, tag}; +use nom7::character::complete::char; +use nom7::combinator::{opt, value}; +use nom7::IResult; + +use std::ffi::CStr; + +#[derive(Debug)] +#[repr(C)] +pub struct DetectUrilenData { + pub du16: DetectUintData<u16>, + pub raw_buffer: bool, +} + +pub fn detect_parse_urilen_raw(i: &str) -> IResult<&str, bool> { + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = char(',')(i)?; + let (i, _) = opt(is_a(" "))(i)?; + return alt((value(true, tag("raw")), value(false, tag("norm"))))(i); +} + +pub fn detect_parse_urilen(i: &str) -> IResult<&str, DetectUrilenData> { + let (i, du16) = detect_parse_uint_notending::<u16>(i)?; + let (i, raw) = opt(detect_parse_urilen_raw)(i)?; + match raw { + Some(raw_buffer) => { + return Ok((i, DetectUrilenData { du16, raw_buffer })); + } + None => { + return Ok(( + i, + DetectUrilenData { + du16, + raw_buffer: false, + }, + )); + } + } +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_urilen_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUrilenData { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_urilen(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_urilen_free(ctx: &mut DetectUrilenData) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} |