From 948a422be120c069e48c63a8770fec7204307897 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 19 Dec 2023 12:01:36 +0100 Subject: Adding upstream version 20.3.0. Signed-off-by: Daniel Baumann --- sqlglotrs/src/lib.rs | 86 ++++++ sqlglotrs/src/settings.rs | 164 +++++++++++ sqlglotrs/src/tokenizer.rs | 670 +++++++++++++++++++++++++++++++++++++++++++++ sqlglotrs/src/trie.rs | 68 +++++ 4 files changed, 988 insertions(+) create mode 100644 sqlglotrs/src/lib.rs create mode 100644 sqlglotrs/src/settings.rs create mode 100644 sqlglotrs/src/tokenizer.rs create mode 100644 sqlglotrs/src/trie.rs (limited to 'sqlglotrs/src') diff --git a/sqlglotrs/src/lib.rs b/sqlglotrs/src/lib.rs new file mode 100644 index 0000000..c962887 --- /dev/null +++ b/sqlglotrs/src/lib.rs @@ -0,0 +1,86 @@ +use pyo3::prelude::*; +use pyo3::types::{PyList, PyNone, PyString}; + +mod settings; +mod tokenizer; +mod trie; + +pub use self::settings::{ + TokenType, TokenTypeSettings, TokenizerDialectSettings, TokenizerSettings, +}; +pub use self::tokenizer::Tokenizer; + +#[derive(Debug)] +#[pyclass] +pub struct Token { + #[pyo3(get, name = "token_type_index")] + pub token_type: TokenType, + #[pyo3(get, set, name = "token_type")] + pub token_type_py: PyObject, + #[pyo3(get)] + pub text: Py, + #[pyo3(get)] + pub line: usize, + #[pyo3(get)] + pub col: usize, + #[pyo3(get)] + pub start: usize, + #[pyo3(get)] + pub end: usize, + #[pyo3(get)] + pub comments: Py, +} + +impl Token { + pub fn new( + token_type: TokenType, + text: String, + line: usize, + col: usize, + start: usize, + end: usize, + comments: Vec, + ) -> Token { + Python::with_gil(|py| Token { + token_type, + token_type_py: PyNone::get(py).into(), + text: PyString::new(py, &text).into(), + line, + col, + start, + end, + comments: PyList::new(py, &comments).into(), + }) + } + + pub fn append_comments(&self, comments: &mut Vec) { + Python::with_gil(|py| { + let pylist = self.comments.as_ref(py); + for comment in comments.iter() { + if let Err(_) = pylist.append(comment) { + panic!("Failed to append comments to the Python list"); + } + } + }); + // Simulate `Vec::append`. + let _ = std::mem::replace(comments, Vec::new()); + } +} + +#[pymethods] +impl Token { + #[pyo3(name = "__repr__")] + fn python_repr(&self) -> PyResult { + Ok(format!("{:?}", self)) + } +} + +#[pymodule] +fn sqlglotrs(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/sqlglotrs/src/settings.rs b/sqlglotrs/src/settings.rs new file mode 100644 index 0000000..32575c6 --- /dev/null +++ b/sqlglotrs/src/settings.rs @@ -0,0 +1,164 @@ +use pyo3::prelude::*; +use std::collections::{HashMap, HashSet}; + +pub type TokenType = u16; + +#[derive(Clone, Debug)] +#[pyclass] +pub struct TokenTypeSettings { + pub bit_string: TokenType, + pub break_: TokenType, + pub dcolon: TokenType, + pub heredoc_string: TokenType, + pub hex_string: TokenType, + pub identifier: TokenType, + pub number: TokenType, + pub parameter: TokenType, + pub semicolon: TokenType, + pub string: TokenType, + pub var: TokenType, +} + +#[pymethods] +impl TokenTypeSettings { + #[new] + pub fn new( + bit_string: TokenType, + break_: TokenType, + dcolon: TokenType, + heredoc_string: TokenType, + hex_string: TokenType, + identifier: TokenType, + number: TokenType, + parameter: TokenType, + semicolon: TokenType, + string: TokenType, + var: TokenType, + ) -> Self { + TokenTypeSettings { + bit_string, + break_, + dcolon, + heredoc_string, + hex_string, + identifier, + number, + parameter, + semicolon, + string, + var, + } + } +} + +#[derive(Clone, Debug)] +#[pyclass] +pub struct TokenizerSettings { + pub white_space: HashMap, + pub single_tokens: HashMap, + pub keywords: HashMap, + pub numeric_literals: HashMap, + pub identifiers: HashMap, + pub identifier_escapes: HashSet, + pub string_escapes: HashSet, + pub quotes: HashMap, + pub format_strings: HashMap, + pub has_bit_strings: bool, + pub has_hex_strings: bool, + pub comments: HashMap>, + pub var_single_tokens: HashSet, + pub commands: HashSet, + pub command_prefix_tokens: HashSet, +} + +#[pymethods] +impl TokenizerSettings { + #[new] + pub fn new( + white_space: HashMap, + single_tokens: HashMap, + keywords: HashMap, + numeric_literals: HashMap, + identifiers: HashMap, + identifier_escapes: HashSet, + string_escapes: HashSet, + quotes: HashMap, + format_strings: HashMap, + has_bit_strings: bool, + has_hex_strings: bool, + comments: HashMap>, + var_single_tokens: HashSet, + commands: HashSet, + command_prefix_tokens: HashSet, + ) -> Self { + let to_char = |v: &String| { + if v.len() == 1 { + v.chars().next().unwrap() + } else { + panic!("Invalid char: {}", v) + } + }; + + let white_space_native: HashMap = white_space + .into_iter() + .map(|(k, v)| (to_char(&k), v)) + .collect(); + + let single_tokens_native: HashMap = single_tokens + .into_iter() + .map(|(k, v)| (to_char(&k), v)) + .collect(); + + let identifiers_native: HashMap = identifiers + .iter() + .map(|(k, v)| (to_char(k), to_char(v))) + .collect(); + + let identifier_escapes_native: HashSet = + identifier_escapes.iter().map(&to_char).collect(); + + let string_escapes_native: HashSet = string_escapes.iter().map(&to_char).collect(); + + let var_single_tokens_native: HashSet = + var_single_tokens.iter().map(&to_char).collect(); + + TokenizerSettings { + white_space: white_space_native, + single_tokens: single_tokens_native, + keywords, + numeric_literals, + identifiers: identifiers_native, + identifier_escapes: identifier_escapes_native, + string_escapes: string_escapes_native, + quotes, + format_strings, + has_bit_strings, + has_hex_strings, + comments, + var_single_tokens: var_single_tokens_native, + commands, + command_prefix_tokens, + } + } +} + +#[derive(Clone, Debug)] +#[pyclass] +pub struct TokenizerDialectSettings { + pub escape_sequences: HashMap, + pub identifiers_can_start_with_digit: bool, +} + +#[pymethods] +impl TokenizerDialectSettings { + #[new] + pub fn new( + escape_sequences: HashMap, + identifiers_can_start_with_digit: bool, + ) -> Self { + TokenizerDialectSettings { + escape_sequences, + identifiers_can_start_with_digit, + } + } +} diff --git a/sqlglotrs/src/tokenizer.rs b/sqlglotrs/src/tokenizer.rs new file mode 100644 index 0000000..920a5b5 --- /dev/null +++ b/sqlglotrs/src/tokenizer.rs @@ -0,0 +1,670 @@ +use crate::trie::{Trie, TrieResult}; +use crate::{Token, TokenType, TokenTypeSettings, TokenizerDialectSettings, TokenizerSettings}; +use pyo3::exceptions::PyException; +use pyo3::prelude::*; +use std::cmp::{max, min}; + +#[derive(Debug)] +pub struct TokenizerError { + message: String, + context: String, +} + +#[derive(Debug)] +#[pyclass] +pub struct Tokenizer { + settings: TokenizerSettings, + token_types: TokenTypeSettings, + keyword_trie: Trie, +} + +#[pymethods] +impl Tokenizer { + #[new] + pub fn new(settings: TokenizerSettings, token_types: TokenTypeSettings) -> Tokenizer { + let mut keyword_trie = Trie::new(); + let single_token_strs: Vec = settings + .single_tokens + .keys() + .map(|s| s.to_string()) + .collect(); + let trie_filter = + |key: &&String| key.contains(" ") || single_token_strs.iter().any(|t| key.contains(t)); + + keyword_trie.add(settings.keywords.keys().filter(trie_filter)); + keyword_trie.add(settings.comments.keys().filter(trie_filter)); + keyword_trie.add(settings.quotes.keys().filter(trie_filter)); + keyword_trie.add(settings.format_strings.keys().filter(trie_filter)); + + Tokenizer { + settings, + token_types, + keyword_trie, + } + } + + pub fn tokenize( + &self, + sql: &str, + dialect_settings: &TokenizerDialectSettings, + ) -> Result, PyErr> { + let mut state = TokenizerState::new( + sql, + &self.settings, + &self.token_types, + dialect_settings, + &self.keyword_trie, + ); + state.tokenize().map_err(|e| { + PyException::new_err(format!("Error tokenizing '{}': {}", e.context, e.message)) + }) + } +} + +#[derive(Debug)] +struct TokenizerState<'a> { + sql: Vec, + size: usize, + tokens: Vec, + start: usize, + current: usize, + line: usize, + column: usize, + comments: Vec, + is_end: bool, + current_char: char, + peek_char: char, + previous_token_line: Option, + keyword_trie: &'a Trie, + settings: &'a TokenizerSettings, + dialect_settings: &'a TokenizerDialectSettings, + token_types: &'a TokenTypeSettings, +} + +impl<'a> TokenizerState<'a> { + fn new( + sql: &str, + settings: &'a TokenizerSettings, + token_types: &'a TokenTypeSettings, + dialect_settings: &'a TokenizerDialectSettings, + keyword_trie: &'a Trie, + ) -> TokenizerState<'a> { + let sql_vec = sql.chars().collect::>(); + let sql_vec_len = sql_vec.len(); + TokenizerState { + sql: sql_vec, + size: sql_vec_len, + tokens: Vec::new(), + start: 0, + current: 0, + line: 1, + column: 0, + comments: Vec::new(), + is_end: false, + current_char: '\0', + peek_char: '\0', + previous_token_line: None, + keyword_trie, + settings, + dialect_settings, + token_types, + } + } + + fn tokenize(&mut self) -> Result, TokenizerError> { + self.scan(None)?; + Ok(std::mem::replace(&mut self.tokens, Vec::new())) + } + + fn scan(&mut self, until_peek_char: Option) -> Result<(), TokenizerError> { + while self.size > 0 && !self.is_end { + self.start = self.current; + self.advance(1)?; + + if self.current_char == '\0' { + break; + } + + if !self.settings.white_space.contains_key(&self.current_char) { + if self.current_char.is_digit(10) { + self.scan_number()?; + } else if let Some(identifier_end) = + self.settings.identifiers.get(&self.current_char) + { + self.scan_identifier(&identifier_end.to_string())?; + } else { + self.scan_keyword()?; + } + } + + if let Some(c) = until_peek_char { + if self.peek_char == c { + break; + } + } + } + if !self.tokens.is_empty() && !self.comments.is_empty() { + self.tokens + .last_mut() + .unwrap() + .append_comments(&mut self.comments); + } + Ok(()) + } + + fn advance(&mut self, i: isize) -> Result<(), TokenizerError> { + let mut i = i; + if Some(&self.token_types.break_) == self.settings.white_space.get(&self.current_char) { + // Ensures we don't count an extra line if we get a \r\n line break sequence. + if self.current_char == '\r' && self.peek_char == '\n' { + i = 2; + self.start += 1; + } + + self.column = 1; + self.line += 1; + } else { + self.column = self.column.wrapping_add_signed(i); + } + + self.current = self.current.wrapping_add_signed(i); + self.is_end = self.current >= self.size; + self.current_char = self.char_at(self.current - 1)?; + self.peek_char = if self.is_end { + '\0' + } else { + self.char_at(self.current)? + }; + Ok(()) + } + + fn peek(&self, i: usize) -> Result { + let index = self.current + i; + if index < self.size { + self.char_at(index) + } else { + Ok('\0') + } + } + + fn chars(&self, size: usize) -> String { + let start = self.current - 1; + let end = start + size; + if end <= self.size { + self.sql[start..end].iter().collect() + } else { + String::from("") + } + } + + fn char_at(&self, index: usize) -> Result { + self.sql.get(index).map(|c| *c).ok_or_else(|| { + self.error(format!( + "Index {} is out of bound (size {})", + index, self.size + )) + }) + } + + fn text(&self) -> String { + self.sql[self.start..self.current].iter().collect() + } + + fn add(&mut self, token_type: TokenType, text: Option) -> Result<(), TokenizerError> { + self.previous_token_line = Some(self.line); + + if !self.comments.is_empty() + && !self.tokens.is_empty() + && token_type == self.token_types.semicolon + { + self.tokens + .last_mut() + .unwrap() + .append_comments(&mut self.comments); + } + + self.tokens.push(Token::new( + token_type, + text.unwrap_or(self.text()), + self.line, + self.column, + self.start, + self.current - 1, + std::mem::replace(&mut self.comments, Vec::new()), + )); + + // If we have either a semicolon or a begin token before the command's token, we'll parse + // whatever follows the command's token as a string. + if self.settings.commands.contains(&token_type) + && self.peek_char != ';' + && (self.tokens.len() == 1 + || self + .settings + .command_prefix_tokens + .contains(&self.tokens[self.tokens.len() - 2].token_type)) + { + let start = self.current; + let tokens_len = self.tokens.len(); + self.scan(Some(';'))?; + self.tokens.truncate(tokens_len); + let text = self.sql[start..self.current] + .iter() + .collect::() + .trim() + .to_string(); + if !text.is_empty() { + self.add(self.token_types.string, Some(text))?; + } + } + Ok(()) + } + + fn scan_keyword(&mut self) -> Result<(), TokenizerError> { + let mut size: usize = 0; + let mut word: Option = None; + let mut chars = self.text(); + let mut current_char = '\0'; + let mut prev_space = false; + let mut skip; + let mut is_single_token = chars.len() == 1 + && self + .settings + .single_tokens + .contains_key(&chars.chars().next().unwrap()); + + let (mut trie_result, mut trie_node) = + self.keyword_trie.root.contains(&chars.to_uppercase()); + + while !chars.is_empty() { + if let TrieResult::Failed = trie_result { + break; + } else if let TrieResult::Exists = trie_result { + word = Some(chars.clone()); + } + + let end = self.current + size; + size += 1; + + if end < self.size { + current_char = self.char_at(end)?; + is_single_token = + is_single_token || self.settings.single_tokens.contains_key(¤t_char); + let is_space = current_char.is_whitespace(); + + if !is_space || !prev_space { + if is_space { + current_char = ' '; + } + chars.push(current_char); + prev_space = is_space; + skip = false; + } else { + skip = true; + } + } else { + current_char = '\0'; + break; + } + + if skip { + trie_result = TrieResult::Prefix; + } else { + (trie_result, trie_node) = + trie_node.contains(¤t_char.to_uppercase().collect::()); + } + } + + if let Some(unwrapped_word) = word { + if self.scan_string(&unwrapped_word)? { + return Ok(()); + } + if self.scan_comment(&unwrapped_word)? { + return Ok(()); + } + if prev_space || is_single_token || current_char == '\0' { + self.advance((size - 1) as isize)?; + let normalized_word = unwrapped_word.to_uppercase(); + let keyword_token = + *self + .settings + .keywords + .get(&normalized_word) + .ok_or_else(|| { + self.error(format!("Unexpected keyword '{}'", &normalized_word)) + })?; + self.add(keyword_token, Some(unwrapped_word))?; + return Ok(()); + } + } + + match self.settings.single_tokens.get(&self.current_char) { + Some(token_type) => self.add(*token_type, Some(self.current_char.to_string())), + None => self.scan_var(), + } + } + + fn scan_comment(&mut self, comment_start: &str) -> Result { + if !self.settings.comments.contains_key(comment_start) { + return Ok(false); + } + + let comment_start_line = self.line; + let comment_start_size = comment_start.len(); + + if let Some(comment_end) = self.settings.comments.get(comment_start).unwrap() { + // Skip the comment's start delimiter. + self.advance(comment_start_size as isize)?; + + let comment_end_size = comment_end.len(); + + while !self.is_end && self.chars(comment_end_size) != *comment_end { + self.advance(1)?; + } + + let text = self.text(); + self.comments + .push(text[comment_start_size..text.len() - comment_end_size + 1].to_string()); + self.advance((comment_end_size - 1) as isize)?; + } else { + while !self.is_end + && self.settings.white_space.get(&self.peek_char) != Some(&self.token_types.break_) + { + self.advance(1)?; + } + self.comments + .push(self.text()[comment_start_size..].to_string()); + } + + // Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. + // Multiple consecutive comments are preserved by appending them to the current comments list. + if Some(comment_start_line) == self.previous_token_line { + self.tokens + .last_mut() + .unwrap() + .append_comments(&mut self.comments); + self.previous_token_line = Some(self.line); + } + + Ok(true) + } + + fn scan_string(&mut self, start: &String) -> Result { + let (base, token_type, end) = if let Some(end) = self.settings.quotes.get(start) { + (None, self.token_types.string, end.clone()) + } else if self.settings.format_strings.contains_key(start) { + let (ref end, token_type) = self.settings.format_strings.get(start).unwrap(); + + if *token_type == self.token_types.hex_string { + (Some(16), *token_type, end.clone()) + } else if *token_type == self.token_types.bit_string { + (Some(2), *token_type, end.clone()) + } else if *token_type == self.token_types.heredoc_string { + self.advance(1)?; + let tag = if self.current_char.to_string() == *end { + String::from("") + } else { + self.extract_string(end, false)? + }; + (None, *token_type, format!("{}{}{}", start, tag, end)) + } else { + (None, *token_type, end.clone()) + } + } else { + return Ok(false); + }; + + self.advance(start.len() as isize)?; + let text = self.extract_string(&end, false)?; + + if let Some(b) = base { + if u64::from_str_radix(&text, b).is_err() { + return self.error_result(format!( + "Numeric string contains invalid characters from {}:{}", + self.line, self.start + )); + } + } + + self.add(token_type, Some(text))?; + Ok(true) + } + + fn scan_number(&mut self) -> Result<(), TokenizerError> { + if self.current_char == '0' { + let peek_char = self.peek_char.to_ascii_uppercase(); + if peek_char == 'B' { + if self.settings.has_bit_strings { + self.scan_bits()?; + } else { + self.add(self.token_types.number, None)?; + } + return Ok(()); + } else if peek_char == 'X' { + if self.settings.has_hex_strings { + self.scan_hex()?; + } else { + self.add(self.token_types.number, None)?; + } + return Ok(()); + } + } + + let mut decimal = false; + let mut scientific = 0; + + loop { + if self.peek_char.is_digit(10) { + self.advance(1)?; + } else if self.peek_char == '.' && !decimal { + let after = self.peek(1)?; + if after.is_digit(10) || !after.is_alphabetic() { + decimal = true; + self.advance(1)?; + } else { + return self.add(self.token_types.var, None); + } + } else if (self.peek_char == '-' || self.peek_char == '+') && scientific == 1 { + scientific += 1; + self.advance(1)?; + } else if self.peek_char.to_ascii_uppercase() == 'E' && scientific == 0 { + scientific += 1; + self.advance(1)?; + } else if self.peek_char.is_alphabetic() || self.peek_char == '_' { + let number_text = self.text(); + let mut literal = String::from(""); + + while !self.peek_char.is_whitespace() + && !self.is_end + && !self.settings.single_tokens.contains_key(&self.peek_char) + { + literal.push(self.peek_char); + self.advance(1)?; + } + + let token_type = self + .settings + .keywords + .get( + self.settings + .numeric_literals + .get(&literal.to_uppercase()) + .unwrap_or(&String::from("")), + ) + .map(|x| *x); + + if let Some(unwrapped_token_type) = token_type { + self.add(self.token_types.number, Some(number_text))?; + self.add(self.token_types.dcolon, Some("::".to_string()))?; + self.add(unwrapped_token_type, Some(literal))?; + } else if self.dialect_settings.identifiers_can_start_with_digit { + self.add(self.token_types.var, None)?; + } else { + self.advance(-(literal.chars().count() as isize))?; + self.add(self.token_types.number, Some(number_text))?; + } + return Ok(()); + } else { + return self.add(self.token_types.number, None); + } + } + } + + fn scan_bits(&mut self) -> Result<(), TokenizerError> { + self.scan_radix_string(2, self.token_types.bit_string) + } + + fn scan_hex(&mut self) -> Result<(), TokenizerError> { + self.scan_radix_string(16, self.token_types.hex_string) + } + + fn scan_radix_string( + &mut self, + radix: u32, + radix_token_type: TokenType, + ) -> Result<(), TokenizerError> { + self.advance(1)?; + let value = self.extract_value()?[2..].to_string(); + match u64::from_str_radix(&value, radix) { + Ok(_) => self.add(radix_token_type, Some(value)), + Err(_) => self.add(self.token_types.identifier, None), + } + } + + fn scan_var(&mut self) -> Result<(), TokenizerError> { + loop { + let peek_char = if !self.peek_char.is_whitespace() { + self.peek_char + } else { + '\0' + }; + if peek_char != '\0' + && (self.settings.var_single_tokens.contains(&peek_char) + || !self.settings.single_tokens.contains_key(&peek_char)) + { + self.advance(1)?; + } else { + break; + } + } + + let token_type = + if self.tokens.last().map(|t| t.token_type) == Some(self.token_types.parameter) { + self.token_types.var + } else { + self.settings + .keywords + .get(&self.text().to_uppercase()) + .map(|x| *x) + .unwrap_or(self.token_types.var) + }; + self.add(token_type, None) + } + + fn scan_identifier(&mut self, identifier_end: &str) -> Result<(), TokenizerError> { + self.advance(1)?; + let text = self.extract_string(identifier_end, true)?; + self.add(self.token_types.identifier, Some(text)) + } + + fn extract_string( + &mut self, + delimiter: &str, + use_identifier_escapes: bool, + ) -> Result { + let mut text = String::from(""); + + loop { + let escapes = if use_identifier_escapes { + &self.settings.identifier_escapes + } else { + &self.settings.string_escapes + }; + + let peek_char_str = self.peek_char.to_string(); + if escapes.contains(&self.current_char) + && (peek_char_str == delimiter || escapes.contains(&self.peek_char)) + && (self.current_char == self.peek_char + || !self + .settings + .quotes + .contains_key(&self.current_char.to_string())) + { + if peek_char_str == delimiter { + text.push(self.peek_char); + } else { + text.push(self.current_char); + text.push(self.peek_char); + } + if self.current + 1 < self.size { + self.advance(2)?; + } else { + return self.error_result(format!( + "Missing {} from {}:{}", + delimiter, self.line, self.current + )); + } + } else { + if self.chars(delimiter.len()) == delimiter { + if delimiter.len() > 1 { + self.advance((delimiter.len() - 1) as isize)?; + } + break; + } + if self.is_end { + return self.error_result(format!( + "Missing {} from {}:{}", + delimiter, self.line, self.current + )); + } + + if !self.dialect_settings.escape_sequences.is_empty() + && !self.peek_char.is_whitespace() + && self.settings.string_escapes.contains(&self.current_char) + { + let sequence_key = format!("{}{}", self.current_char, self.peek_char); + if let Some(escaped_sequence) = + self.dialect_settings.escape_sequences.get(&sequence_key) + { + self.advance(2)?; + text.push_str(escaped_sequence); + continue; + } + } + + let current = self.current - 1; + self.advance(1)?; + text.push_str( + &self.sql[current..self.current - 1] + .iter() + .collect::(), + ); + } + } + Ok(text) + } + + fn extract_value(&mut self) -> Result { + loop { + if !self.peek_char.is_whitespace() + && !self.is_end + && !self.settings.single_tokens.contains_key(&self.peek_char) + { + self.advance(1)?; + } else { + break; + } + } + Ok(self.text()) + } + + fn error(&self, message: String) -> TokenizerError { + let start = max((self.current as isize) - 50, 0); + let end = min(self.current + 50, self.size - 1); + let context = self.sql[start as usize..end].iter().collect::(); + TokenizerError { message, context } + } + + fn error_result(&self, message: String) -> Result { + Err(self.error(message)) + } +} diff --git a/sqlglotrs/src/trie.rs b/sqlglotrs/src/trie.rs new file mode 100644 index 0000000..8e6f20c --- /dev/null +++ b/sqlglotrs/src/trie.rs @@ -0,0 +1,68 @@ +use std::collections::HashMap; + +#[derive(Debug)] +pub struct TrieNode { + is_word: bool, + children: HashMap, +} + +#[derive(Debug)] +pub enum TrieResult { + Failed, + Prefix, + Exists, +} + +impl TrieNode { + pub fn contains(&self, key: &str) -> (TrieResult, &TrieNode) { + if key.is_empty() { + return (TrieResult::Failed, self); + } + + let mut current = self; + for c in key.chars() { + match current.children.get(&c) { + Some(node) => current = node, + None => return (TrieResult::Failed, current), + } + } + + if current.is_word { + (TrieResult::Exists, current) + } else { + (TrieResult::Prefix, current) + } + } +} + +#[derive(Debug)] +pub struct Trie { + pub root: TrieNode, +} + +impl Trie { + pub fn new() -> Self { + Trie { + root: TrieNode { + is_word: false, + children: HashMap::new(), + }, + } + } + + pub fn add<'a, I>(&mut self, keys: I) + where + I: Iterator, + { + for key in keys { + let mut current = &mut self.root; + for c in key.chars() { + current = current.children.entry(c).or_insert(TrieNode { + is_word: false, + children: HashMap::new(), + }); + } + current.is_word = true; + } + } +} -- cgit v1.2.3