author    Daniel Baumann <daniel.baumann@progress-linux.org>  2023-12-19 11:01:55 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org>  2023-12-19 11:01:55 +0000
commit    f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5 (patch)
tree      5dce0fe2a11381761496eb973c20750f44db56d5 /sqlglotrs/src
parent    Releasing debian version 20.1.0-1. (diff)
Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglotrs/src')
-rw-r--r--  sqlglotrs/src/lib.rs        86
-rw-r--r--  sqlglotrs/src/settings.rs  164
-rw-r--r--  sqlglotrs/src/tokenizer.rs 670
-rw-r--r--  sqlglotrs/src/trie.rs       68
4 files changed, 988 insertions, 0 deletions
diff --git a/sqlglotrs/src/lib.rs b/sqlglotrs/src/lib.rs
new file mode 100644
index 0000000..c962887
--- /dev/null
+++ b/sqlglotrs/src/lib.rs
@@ -0,0 +1,86 @@
+use pyo3::prelude::*;
+use pyo3::types::{PyList, PyNone, PyString};
+
+mod settings;
+mod tokenizer;
+mod trie;
+
+pub use self::settings::{
+ TokenType, TokenTypeSettings, TokenizerDialectSettings, TokenizerSettings,
+};
+pub use self::tokenizer::Tokenizer;
+
+#[derive(Debug)]
+#[pyclass]
+pub struct Token {
+ #[pyo3(get, name = "token_type_index")]
+ pub token_type: TokenType,
+ #[pyo3(get, set, name = "token_type")]
+ pub token_type_py: PyObject,
+ #[pyo3(get)]
+ pub text: Py<PyString>,
+ #[pyo3(get)]
+ pub line: usize,
+ #[pyo3(get)]
+ pub col: usize,
+ #[pyo3(get)]
+ pub start: usize,
+ #[pyo3(get)]
+ pub end: usize,
+ #[pyo3(get)]
+ pub comments: Py<PyList>,
+}
+
+impl Token {
+ pub fn new(
+ token_type: TokenType,
+ text: String,
+ line: usize,
+ col: usize,
+ start: usize,
+ end: usize,
+ comments: Vec<String>,
+ ) -> Token {
+ Python::with_gil(|py| Token {
+ token_type,
+ token_type_py: PyNone::get(py).into(),
+ text: PyString::new(py, &text).into(),
+ line,
+ col,
+ start,
+ end,
+ comments: PyList::new(py, &comments).into(),
+ })
+ }
+
+ pub fn append_comments(&self, comments: &mut Vec<String>) {
+ Python::with_gil(|py| {
+ let pylist = self.comments.as_ref(py);
+ for comment in comments.iter() {
+ if let Err(_) = pylist.append(comment) {
+ panic!("Failed to append comments to the Python list");
+ }
+ }
+ });
+ // Simulate `Vec::append`.
+ let _ = std::mem::replace(comments, Vec::new());
+ }
+}
+
+#[pymethods]
+impl Token {
+ #[pyo3(name = "__repr__")]
+ fn python_repr(&self) -> PyResult<String> {
+ Ok(format!("{:?}", self))
+ }
+}
+
+#[pymodule]
+fn sqlglotrs(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+ m.add_class::<Token>()?;
+ m.add_class::<TokenTypeSettings>()?;
+ m.add_class::<TokenizerSettings>()?;
+ m.add_class::<TokenizerDialectSettings>()?;
+ m.add_class::<Tokenizer>()?;
+ Ok(())
+}
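Note on lib.rs above: Token::new allocates the Python-side text and comments objects under the GIL, and append_comments drains a Rust Vec<String> into the Python list, mirroring Vec::append. A minimal test sketch (not part of the patch; the token type index and offsets are placeholder values, and it assumes the pre-0.21 pyo3 GIL-ref API that the code above already uses) could exercise that behaviour:

    // Hypothetical test module that could sit at the bottom of lib.rs.
    #[cfg(test)]
    mod token_tests {
        use super::Token;
        use pyo3::prelude::*;

        #[test]
        fn append_comments_drains_the_rust_vec() {
            pyo3::prepare_freethreaded_python();
            // 1 is an arbitrary placeholder token type index.
            let token = Token::new(1, "SELECT".to_string(), 1, 6, 0, 5, vec![]);
            let mut comments = vec!["-- leading comment".to_string()];
            token.append_comments(&mut comments);
            assert!(comments.is_empty()); // drained, like Vec::append
            Python::with_gil(|py| {
                assert_eq!(token.comments.as_ref(py).len(), 1);
            });
        }
    }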
diff --git a/sqlglotrs/src/settings.rs b/sqlglotrs/src/settings.rs
new file mode 100644
index 0000000..32575c6
--- /dev/null
+++ b/sqlglotrs/src/settings.rs
@@ -0,0 +1,164 @@
+use pyo3::prelude::*;
+use std::collections::{HashMap, HashSet};
+
+pub type TokenType = u16;
+
+#[derive(Clone, Debug)]
+#[pyclass]
+pub struct TokenTypeSettings {
+ pub bit_string: TokenType,
+ pub break_: TokenType,
+ pub dcolon: TokenType,
+ pub heredoc_string: TokenType,
+ pub hex_string: TokenType,
+ pub identifier: TokenType,
+ pub number: TokenType,
+ pub parameter: TokenType,
+ pub semicolon: TokenType,
+ pub string: TokenType,
+ pub var: TokenType,
+}
+
+#[pymethods]
+impl TokenTypeSettings {
+ #[new]
+ pub fn new(
+ bit_string: TokenType,
+ break_: TokenType,
+ dcolon: TokenType,
+ heredoc_string: TokenType,
+ hex_string: TokenType,
+ identifier: TokenType,
+ number: TokenType,
+ parameter: TokenType,
+ semicolon: TokenType,
+ string: TokenType,
+ var: TokenType,
+ ) -> Self {
+ TokenTypeSettings {
+ bit_string,
+ break_,
+ dcolon,
+ heredoc_string,
+ hex_string,
+ identifier,
+ number,
+ parameter,
+ semicolon,
+ string,
+ var,
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+#[pyclass]
+pub struct TokenizerSettings {
+ pub white_space: HashMap<char, TokenType>,
+ pub single_tokens: HashMap<char, TokenType>,
+ pub keywords: HashMap<String, TokenType>,
+ pub numeric_literals: HashMap<String, String>,
+ pub identifiers: HashMap<char, char>,
+ pub identifier_escapes: HashSet<char>,
+ pub string_escapes: HashSet<char>,
+ pub quotes: HashMap<String, String>,
+ pub format_strings: HashMap<String, (String, TokenType)>,
+ pub has_bit_strings: bool,
+ pub has_hex_strings: bool,
+ pub comments: HashMap<String, Option<String>>,
+ pub var_single_tokens: HashSet<char>,
+ pub commands: HashSet<TokenType>,
+ pub command_prefix_tokens: HashSet<TokenType>,
+}
+
+#[pymethods]
+impl TokenizerSettings {
+ #[new]
+ pub fn new(
+ white_space: HashMap<String, TokenType>,
+ single_tokens: HashMap<String, TokenType>,
+ keywords: HashMap<String, TokenType>,
+ numeric_literals: HashMap<String, String>,
+ identifiers: HashMap<String, String>,
+ identifier_escapes: HashSet<String>,
+ string_escapes: HashSet<String>,
+ quotes: HashMap<String, String>,
+ format_strings: HashMap<String, (String, TokenType)>,
+ has_bit_strings: bool,
+ has_hex_strings: bool,
+ comments: HashMap<String, Option<String>>,
+ var_single_tokens: HashSet<String>,
+ commands: HashSet<TokenType>,
+ command_prefix_tokens: HashSet<TokenType>,
+ ) -> Self {
+ let to_char = |v: &String| {
+ if v.len() == 1 {
+ v.chars().next().unwrap()
+ } else {
+ panic!("Invalid char: {}", v)
+ }
+ };
+
+ let white_space_native: HashMap<char, TokenType> = white_space
+ .into_iter()
+ .map(|(k, v)| (to_char(&k), v))
+ .collect();
+
+ let single_tokens_native: HashMap<char, TokenType> = single_tokens
+ .into_iter()
+ .map(|(k, v)| (to_char(&k), v))
+ .collect();
+
+ let identifiers_native: HashMap<char, char> = identifiers
+ .iter()
+ .map(|(k, v)| (to_char(k), to_char(v)))
+ .collect();
+
+ let identifier_escapes_native: HashSet<char> =
+ identifier_escapes.iter().map(&to_char).collect();
+
+ let string_escapes_native: HashSet<char> = string_escapes.iter().map(&to_char).collect();
+
+ let var_single_tokens_native: HashSet<char> =
+ var_single_tokens.iter().map(&to_char).collect();
+
+ TokenizerSettings {
+ white_space: white_space_native,
+ single_tokens: single_tokens_native,
+ keywords,
+ numeric_literals,
+ identifiers: identifiers_native,
+ identifier_escapes: identifier_escapes_native,
+ string_escapes: string_escapes_native,
+ quotes,
+ format_strings,
+ has_bit_strings,
+ has_hex_strings,
+ comments,
+ var_single_tokens: var_single_tokens_native,
+ commands,
+ command_prefix_tokens,
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+#[pyclass]
+pub struct TokenizerDialectSettings {
+ pub escape_sequences: HashMap<String, String>,
+ pub identifiers_can_start_with_digit: bool,
+}
+
+#[pymethods]
+impl TokenizerDialectSettings {
+ #[new]
+ pub fn new(
+ escape_sequences: HashMap<String, String>,
+ identifiers_can_start_with_digit: bool,
+ ) -> Self {
+ TokenizerDialectSettings {
+ escape_sequences,
+ identifiers_can_start_with_digit,
+ }
+ }
+}
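Note on settings.rs above: the constructors accept the maps exactly as the Python side hands them over (string-keyed) and narrow one-character keys to Rust char, panicking on anything longer. A standalone sketch of that narrowing step, with made-up token type values:

    // Standalone illustration of the char-narrowing done in TokenizerSettings::new.
    use std::collections::HashMap;

    fn to_char(v: &String) -> char {
        if v.len() == 1 {
            v.chars().next().unwrap()
        } else {
            panic!("Invalid char: {}", v)
        }
    }

    fn main() {
        // Placeholder token type values; the real ones come from TokenTypeSettings.
        let single_tokens: HashMap<String, u16> =
            HashMap::from([("(".to_string(), 10), (")".to_string(), 11)]);
        let native: HashMap<char, u16> = single_tokens
            .into_iter()
            .map(|(k, v)| (to_char(&k), v))
            .collect();
        assert_eq!(native.get(&'('), Some(&10));
    }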
diff --git a/sqlglotrs/src/tokenizer.rs b/sqlglotrs/src/tokenizer.rs
new file mode 100644
index 0000000..920a5b5
--- /dev/null
+++ b/sqlglotrs/src/tokenizer.rs
@@ -0,0 +1,670 @@
+use crate::trie::{Trie, TrieResult};
+use crate::{Token, TokenType, TokenTypeSettings, TokenizerDialectSettings, TokenizerSettings};
+use pyo3::exceptions::PyException;
+use pyo3::prelude::*;
+use std::cmp::{max, min};
+
+#[derive(Debug)]
+pub struct TokenizerError {
+ message: String,
+ context: String,
+}
+
+#[derive(Debug)]
+#[pyclass]
+pub struct Tokenizer {
+ settings: TokenizerSettings,
+ token_types: TokenTypeSettings,
+ keyword_trie: Trie,
+}
+
+#[pymethods]
+impl Tokenizer {
+ #[new]
+ pub fn new(settings: TokenizerSettings, token_types: TokenTypeSettings) -> Tokenizer {
+ let mut keyword_trie = Trie::new();
+ let single_token_strs: Vec<String> = settings
+ .single_tokens
+ .keys()
+ .map(|s| s.to_string())
+ .collect();
+ let trie_filter =
+ |key: &&String| key.contains(" ") || single_token_strs.iter().any(|t| key.contains(t));
+
+ keyword_trie.add(settings.keywords.keys().filter(trie_filter));
+ keyword_trie.add(settings.comments.keys().filter(trie_filter));
+ keyword_trie.add(settings.quotes.keys().filter(trie_filter));
+ keyword_trie.add(settings.format_strings.keys().filter(trie_filter));
+
+ Tokenizer {
+ settings,
+ token_types,
+ keyword_trie,
+ }
+ }
+
+ pub fn tokenize(
+ &self,
+ sql: &str,
+ dialect_settings: &TokenizerDialectSettings,
+ ) -> Result<Vec<Token>, PyErr> {
+ let mut state = TokenizerState::new(
+ sql,
+ &self.settings,
+ &self.token_types,
+ dialect_settings,
+ &self.keyword_trie,
+ );
+ state.tokenize().map_err(|e| {
+ PyException::new_err(format!("Error tokenizing '{}': {}", e.context, e.message))
+ })
+ }
+}
+
+#[derive(Debug)]
+struct TokenizerState<'a> {
+ sql: Vec<char>,
+ size: usize,
+ tokens: Vec<Token>,
+ start: usize,
+ current: usize,
+ line: usize,
+ column: usize,
+ comments: Vec<String>,
+ is_end: bool,
+ current_char: char,
+ peek_char: char,
+ previous_token_line: Option<usize>,
+ keyword_trie: &'a Trie,
+ settings: &'a TokenizerSettings,
+ dialect_settings: &'a TokenizerDialectSettings,
+ token_types: &'a TokenTypeSettings,
+}
+
+impl<'a> TokenizerState<'a> {
+ fn new(
+ sql: &str,
+ settings: &'a TokenizerSettings,
+ token_types: &'a TokenTypeSettings,
+ dialect_settings: &'a TokenizerDialectSettings,
+ keyword_trie: &'a Trie,
+ ) -> TokenizerState<'a> {
+ let sql_vec = sql.chars().collect::<Vec<char>>();
+ let sql_vec_len = sql_vec.len();
+ TokenizerState {
+ sql: sql_vec,
+ size: sql_vec_len,
+ tokens: Vec::new(),
+ start: 0,
+ current: 0,
+ line: 1,
+ column: 0,
+ comments: Vec::new(),
+ is_end: false,
+ current_char: '\0',
+ peek_char: '\0',
+ previous_token_line: None,
+ keyword_trie,
+ settings,
+ dialect_settings,
+ token_types,
+ }
+ }
+
+ fn tokenize(&mut self) -> Result<Vec<Token>, TokenizerError> {
+ self.scan(None)?;
+ Ok(std::mem::replace(&mut self.tokens, Vec::new()))
+ }
+
+ fn scan(&mut self, until_peek_char: Option<char>) -> Result<(), TokenizerError> {
+ while self.size > 0 && !self.is_end {
+ self.start = self.current;
+ self.advance(1)?;
+
+ if self.current_char == '\0' {
+ break;
+ }
+
+ if !self.settings.white_space.contains_key(&self.current_char) {
+ if self.current_char.is_digit(10) {
+ self.scan_number()?;
+ } else if let Some(identifier_end) =
+ self.settings.identifiers.get(&self.current_char)
+ {
+ self.scan_identifier(&identifier_end.to_string())?;
+ } else {
+ self.scan_keyword()?;
+ }
+ }
+
+ if let Some(c) = until_peek_char {
+ if self.peek_char == c {
+ break;
+ }
+ }
+ }
+ if !self.tokens.is_empty() && !self.comments.is_empty() {
+ self.tokens
+ .last_mut()
+ .unwrap()
+ .append_comments(&mut self.comments);
+ }
+ Ok(())
+ }
+
+ fn advance(&mut self, i: isize) -> Result<(), TokenizerError> {
+ let mut i = i;
+ if Some(&self.token_types.break_) == self.settings.white_space.get(&self.current_char) {
+ // Ensures we don't count an extra line if we get a \r\n line break sequence.
+ if self.current_char == '\r' && self.peek_char == '\n' {
+ i = 2;
+ self.start += 1;
+ }
+
+ self.column = 1;
+ self.line += 1;
+ } else {
+ self.column = self.column.wrapping_add_signed(i);
+ }
+
+ self.current = self.current.wrapping_add_signed(i);
+ self.is_end = self.current >= self.size;
+ self.current_char = self.char_at(self.current - 1)?;
+ self.peek_char = if self.is_end {
+ '\0'
+ } else {
+ self.char_at(self.current)?
+ };
+ Ok(())
+ }
+
+ fn peek(&self, i: usize) -> Result<char, TokenizerError> {
+ let index = self.current + i;
+ if index < self.size {
+ self.char_at(index)
+ } else {
+ Ok('\0')
+ }
+ }
+
+ fn chars(&self, size: usize) -> String {
+ let start = self.current - 1;
+ let end = start + size;
+ if end <= self.size {
+ self.sql[start..end].iter().collect()
+ } else {
+ String::from("")
+ }
+ }
+
+ fn char_at(&self, index: usize) -> Result<char, TokenizerError> {
+ self.sql.get(index).map(|c| *c).ok_or_else(|| {
+ self.error(format!(
+ "Index {} is out of bound (size {})",
+ index, self.size
+ ))
+ })
+ }
+
+ fn text(&self) -> String {
+ self.sql[self.start..self.current].iter().collect()
+ }
+
+ fn add(&mut self, token_type: TokenType, text: Option<String>) -> Result<(), TokenizerError> {
+ self.previous_token_line = Some(self.line);
+
+ if !self.comments.is_empty()
+ && !self.tokens.is_empty()
+ && token_type == self.token_types.semicolon
+ {
+ self.tokens
+ .last_mut()
+ .unwrap()
+ .append_comments(&mut self.comments);
+ }
+
+ self.tokens.push(Token::new(
+ token_type,
+ text.unwrap_or(self.text()),
+ self.line,
+ self.column,
+ self.start,
+ self.current - 1,
+ std::mem::replace(&mut self.comments, Vec::new()),
+ ));
+
+ // If we have either a semicolon or a begin token before the command's token, we'll parse
+ // whatever follows the command's token as a string.
+ if self.settings.commands.contains(&token_type)
+ && self.peek_char != ';'
+ && (self.tokens.len() == 1
+ || self
+ .settings
+ .command_prefix_tokens
+ .contains(&self.tokens[self.tokens.len() - 2].token_type))
+ {
+ let start = self.current;
+ let tokens_len = self.tokens.len();
+ self.scan(Some(';'))?;
+ self.tokens.truncate(tokens_len);
+ let text = self.sql[start..self.current]
+ .iter()
+ .collect::<String>()
+ .trim()
+ .to_string();
+ if !text.is_empty() {
+ self.add(self.token_types.string, Some(text))?;
+ }
+ }
+ Ok(())
+ }
+
+ fn scan_keyword(&mut self) -> Result<(), TokenizerError> {
+ let mut size: usize = 0;
+ let mut word: Option<String> = None;
+ let mut chars = self.text();
+ let mut current_char = '\0';
+ let mut prev_space = false;
+ let mut skip;
+ let mut is_single_token = chars.len() == 1
+ && self
+ .settings
+ .single_tokens
+ .contains_key(&chars.chars().next().unwrap());
+
+ let (mut trie_result, mut trie_node) =
+ self.keyword_trie.root.contains(&chars.to_uppercase());
+
+ while !chars.is_empty() {
+ if let TrieResult::Failed = trie_result {
+ break;
+ } else if let TrieResult::Exists = trie_result {
+ word = Some(chars.clone());
+ }
+
+ let end = self.current + size;
+ size += 1;
+
+ if end < self.size {
+ current_char = self.char_at(end)?;
+ is_single_token =
+ is_single_token || self.settings.single_tokens.contains_key(&current_char);
+ let is_space = current_char.is_whitespace();
+
+ if !is_space || !prev_space {
+ if is_space {
+ current_char = ' ';
+ }
+ chars.push(current_char);
+ prev_space = is_space;
+ skip = false;
+ } else {
+ skip = true;
+ }
+ } else {
+ current_char = '\0';
+ break;
+ }
+
+ if skip {
+ trie_result = TrieResult::Prefix;
+ } else {
+ (trie_result, trie_node) =
+ trie_node.contains(&current_char.to_uppercase().collect::<String>());
+ }
+ }
+
+ if let Some(unwrapped_word) = word {
+ if self.scan_string(&unwrapped_word)? {
+ return Ok(());
+ }
+ if self.scan_comment(&unwrapped_word)? {
+ return Ok(());
+ }
+ if prev_space || is_single_token || current_char == '\0' {
+ self.advance((size - 1) as isize)?;
+ let normalized_word = unwrapped_word.to_uppercase();
+ let keyword_token =
+ *self
+ .settings
+ .keywords
+ .get(&normalized_word)
+ .ok_or_else(|| {
+ self.error(format!("Unexpected keyword '{}'", &normalized_word))
+ })?;
+ self.add(keyword_token, Some(unwrapped_word))?;
+ return Ok(());
+ }
+ }
+
+ match self.settings.single_tokens.get(&self.current_char) {
+ Some(token_type) => self.add(*token_type, Some(self.current_char.to_string())),
+ None => self.scan_var(),
+ }
+ }
+
+ fn scan_comment(&mut self, comment_start: &str) -> Result<bool, TokenizerError> {
+ if !self.settings.comments.contains_key(comment_start) {
+ return Ok(false);
+ }
+
+ let comment_start_line = self.line;
+ let comment_start_size = comment_start.len();
+
+ if let Some(comment_end) = self.settings.comments.get(comment_start).unwrap() {
+ // Skip the comment's start delimiter.
+ self.advance(comment_start_size as isize)?;
+
+ let comment_end_size = comment_end.len();
+
+ while !self.is_end && self.chars(comment_end_size) != *comment_end {
+ self.advance(1)?;
+ }
+
+ let text = self.text();
+ self.comments
+ .push(text[comment_start_size..text.len() - comment_end_size + 1].to_string());
+ self.advance((comment_end_size - 1) as isize)?;
+ } else {
+ while !self.is_end
+ && self.settings.white_space.get(&self.peek_char) != Some(&self.token_types.break_)
+ {
+ self.advance(1)?;
+ }
+ self.comments
+ .push(self.text()[comment_start_size..].to_string());
+ }
+
+ // Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
+ // Multiple consecutive comments are preserved by appending them to the current comments list.
+ if Some(comment_start_line) == self.previous_token_line {
+ self.tokens
+ .last_mut()
+ .unwrap()
+ .append_comments(&mut self.comments);
+ self.previous_token_line = Some(self.line);
+ }
+
+ Ok(true)
+ }
+
+ fn scan_string(&mut self, start: &String) -> Result<bool, TokenizerError> {
+ let (base, token_type, end) = if let Some(end) = self.settings.quotes.get(start) {
+ (None, self.token_types.string, end.clone())
+ } else if self.settings.format_strings.contains_key(start) {
+ let (ref end, token_type) = self.settings.format_strings.get(start).unwrap();
+
+ if *token_type == self.token_types.hex_string {
+ (Some(16), *token_type, end.clone())
+ } else if *token_type == self.token_types.bit_string {
+ (Some(2), *token_type, end.clone())
+ } else if *token_type == self.token_types.heredoc_string {
+ self.advance(1)?;
+ let tag = if self.current_char.to_string() == *end {
+ String::from("")
+ } else {
+ self.extract_string(end, false)?
+ };
+ (None, *token_type, format!("{}{}{}", start, tag, end))
+ } else {
+ (None, *token_type, end.clone())
+ }
+ } else {
+ return Ok(false);
+ };
+
+ self.advance(start.len() as isize)?;
+ let text = self.extract_string(&end, false)?;
+
+ if let Some(b) = base {
+ if u64::from_str_radix(&text, b).is_err() {
+ return self.error_result(format!(
+ "Numeric string contains invalid characters from {}:{}",
+ self.line, self.start
+ ));
+ }
+ }
+
+ self.add(token_type, Some(text))?;
+ Ok(true)
+ }
+
+ fn scan_number(&mut self) -> Result<(), TokenizerError> {
+ if self.current_char == '0' {
+ let peek_char = self.peek_char.to_ascii_uppercase();
+ if peek_char == 'B' {
+ if self.settings.has_bit_strings {
+ self.scan_bits()?;
+ } else {
+ self.add(self.token_types.number, None)?;
+ }
+ return Ok(());
+ } else if peek_char == 'X' {
+ if self.settings.has_hex_strings {
+ self.scan_hex()?;
+ } else {
+ self.add(self.token_types.number, None)?;
+ }
+ return Ok(());
+ }
+ }
+
+ let mut decimal = false;
+ let mut scientific = 0;
+
+ loop {
+ if self.peek_char.is_digit(10) {
+ self.advance(1)?;
+ } else if self.peek_char == '.' && !decimal {
+ let after = self.peek(1)?;
+ if after.is_digit(10) || !after.is_alphabetic() {
+ decimal = true;
+ self.advance(1)?;
+ } else {
+ return self.add(self.token_types.var, None);
+ }
+ } else if (self.peek_char == '-' || self.peek_char == '+') && scientific == 1 {
+ scientific += 1;
+ self.advance(1)?;
+ } else if self.peek_char.to_ascii_uppercase() == 'E' && scientific == 0 {
+ scientific += 1;
+ self.advance(1)?;
+ } else if self.peek_char.is_alphabetic() || self.peek_char == '_' {
+ let number_text = self.text();
+ let mut literal = String::from("");
+
+ while !self.peek_char.is_whitespace()
+ && !self.is_end
+ && !self.settings.single_tokens.contains_key(&self.peek_char)
+ {
+ literal.push(self.peek_char);
+ self.advance(1)?;
+ }
+
+ let token_type = self
+ .settings
+ .keywords
+ .get(
+ self.settings
+ .numeric_literals
+ .get(&literal.to_uppercase())
+ .unwrap_or(&String::from("")),
+ )
+ .map(|x| *x);
+
+ if let Some(unwrapped_token_type) = token_type {
+ self.add(self.token_types.number, Some(number_text))?;
+ self.add(self.token_types.dcolon, Some("::".to_string()))?;
+ self.add(unwrapped_token_type, Some(literal))?;
+ } else if self.dialect_settings.identifiers_can_start_with_digit {
+ self.add(self.token_types.var, None)?;
+ } else {
+ self.advance(-(literal.chars().count() as isize))?;
+ self.add(self.token_types.number, Some(number_text))?;
+ }
+ return Ok(());
+ } else {
+ return self.add(self.token_types.number, None);
+ }
+ }
+ }
+
+ fn scan_bits(&mut self) -> Result<(), TokenizerError> {
+ self.scan_radix_string(2, self.token_types.bit_string)
+ }
+
+ fn scan_hex(&mut self) -> Result<(), TokenizerError> {
+ self.scan_radix_string(16, self.token_types.hex_string)
+ }
+
+ fn scan_radix_string(
+ &mut self,
+ radix: u32,
+ radix_token_type: TokenType,
+ ) -> Result<(), TokenizerError> {
+ self.advance(1)?;
+ let value = self.extract_value()?[2..].to_string();
+ match u64::from_str_radix(&value, radix) {
+ Ok(_) => self.add(radix_token_type, Some(value)),
+ Err(_) => self.add(self.token_types.identifier, None),
+ }
+ }
+
+ fn scan_var(&mut self) -> Result<(), TokenizerError> {
+ loop {
+ let peek_char = if !self.peek_char.is_whitespace() {
+ self.peek_char
+ } else {
+ '\0'
+ };
+ if peek_char != '\0'
+ && (self.settings.var_single_tokens.contains(&peek_char)
+ || !self.settings.single_tokens.contains_key(&peek_char))
+ {
+ self.advance(1)?;
+ } else {
+ break;
+ }
+ }
+
+ let token_type =
+ if self.tokens.last().map(|t| t.token_type) == Some(self.token_types.parameter) {
+ self.token_types.var
+ } else {
+ self.settings
+ .keywords
+ .get(&self.text().to_uppercase())
+ .map(|x| *x)
+ .unwrap_or(self.token_types.var)
+ };
+ self.add(token_type, None)
+ }
+
+ fn scan_identifier(&mut self, identifier_end: &str) -> Result<(), TokenizerError> {
+ self.advance(1)?;
+ let text = self.extract_string(identifier_end, true)?;
+ self.add(self.token_types.identifier, Some(text))
+ }
+
+ fn extract_string(
+ &mut self,
+ delimiter: &str,
+ use_identifier_escapes: bool,
+ ) -> Result<String, TokenizerError> {
+ let mut text = String::from("");
+
+ loop {
+ let escapes = if use_identifier_escapes {
+ &self.settings.identifier_escapes
+ } else {
+ &self.settings.string_escapes
+ };
+
+ let peek_char_str = self.peek_char.to_string();
+ if escapes.contains(&self.current_char)
+ && (peek_char_str == delimiter || escapes.contains(&self.peek_char))
+ && (self.current_char == self.peek_char
+ || !self
+ .settings
+ .quotes
+ .contains_key(&self.current_char.to_string()))
+ {
+ if peek_char_str == delimiter {
+ text.push(self.peek_char);
+ } else {
+ text.push(self.current_char);
+ text.push(self.peek_char);
+ }
+ if self.current + 1 < self.size {
+ self.advance(2)?;
+ } else {
+ return self.error_result(format!(
+ "Missing {} from {}:{}",
+ delimiter, self.line, self.current
+ ));
+ }
+ } else {
+ if self.chars(delimiter.len()) == delimiter {
+ if delimiter.len() > 1 {
+ self.advance((delimiter.len() - 1) as isize)?;
+ }
+ break;
+ }
+ if self.is_end {
+ return self.error_result(format!(
+ "Missing {} from {}:{}",
+ delimiter, self.line, self.current
+ ));
+ }
+
+ if !self.dialect_settings.escape_sequences.is_empty()
+ && !self.peek_char.is_whitespace()
+ && self.settings.string_escapes.contains(&self.current_char)
+ {
+ let sequence_key = format!("{}{}", self.current_char, self.peek_char);
+ if let Some(escaped_sequence) =
+ self.dialect_settings.escape_sequences.get(&sequence_key)
+ {
+ self.advance(2)?;
+ text.push_str(escaped_sequence);
+ continue;
+ }
+ }
+
+ let current = self.current - 1;
+ self.advance(1)?;
+ text.push_str(
+ &self.sql[current..self.current - 1]
+ .iter()
+ .collect::<String>(),
+ );
+ }
+ }
+ Ok(text)
+ }
+
+ fn extract_value(&mut self) -> Result<String, TokenizerError> {
+ loop {
+ if !self.peek_char.is_whitespace()
+ && !self.is_end
+ && !self.settings.single_tokens.contains_key(&self.peek_char)
+ {
+ self.advance(1)?;
+ } else {
+ break;
+ }
+ }
+ Ok(self.text())
+ }
+
+ fn error(&self, message: String) -> TokenizerError {
+ let start = max((self.current as isize) - 50, 0);
+ let end = min(self.current + 50, self.size - 1);
+ let context = self.sql[start as usize..end].iter().collect::<String>();
+ TokenizerError { message, context }
+ }
+
+ fn error_result<T>(&self, message: String) -> Result<T, TokenizerError> {
+ Err(self.error(message))
+ }
+}
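Note on Tokenizer::new above: the keyword trie is only seeded with keys that contain a space or embed a single-token character, because plain one-word keywords are resolved later through the keywords map in scan_var. A small sketch of that filter with illustrative inputs (the real keyword, comment and quote keys come from the Python-side settings):

    fn main() {
        let single_token_strs = vec!["-".to_string(), "(".to_string()];
        let keys = ["SELECT", "GROUP BY", "CROSS JOIN", "--"];

        let trie_keys: Vec<&str> = keys
            .iter()
            .copied()
            .filter(|key| {
                key.contains(' ') || single_token_strs.iter().any(|t| key.contains(t))
            })
            .collect();

        // "GROUP BY", "CROSS JOIN" and "--" end up in the trie; "SELECT" does not.
        assert_eq!(trie_keys, ["GROUP BY", "CROSS JOIN", "--"]);
    }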
diff --git a/sqlglotrs/src/trie.rs b/sqlglotrs/src/trie.rs
new file mode 100644
index 0000000..8e6f20c
--- /dev/null
+++ b/sqlglotrs/src/trie.rs
@@ -0,0 +1,68 @@
+use std::collections::HashMap;
+
+#[derive(Debug)]
+pub struct TrieNode {
+ is_word: bool,
+ children: HashMap<char, TrieNode>,
+}
+
+#[derive(Debug)]
+pub enum TrieResult {
+ Failed,
+ Prefix,
+ Exists,
+}
+
+impl TrieNode {
+ pub fn contains(&self, key: &str) -> (TrieResult, &TrieNode) {
+ if key.is_empty() {
+ return (TrieResult::Failed, self);
+ }
+
+ let mut current = self;
+ for c in key.chars() {
+ match current.children.get(&c) {
+ Some(node) => current = node,
+ None => return (TrieResult::Failed, current),
+ }
+ }
+
+ if current.is_word {
+ (TrieResult::Exists, current)
+ } else {
+ (TrieResult::Prefix, current)
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Trie {
+ pub root: TrieNode,
+}
+
+impl Trie {
+ pub fn new() -> Self {
+ Trie {
+ root: TrieNode {
+ is_word: false,
+ children: HashMap::new(),
+ },
+ }
+ }
+
+ pub fn add<'a, I>(&mut self, keys: I)
+ where
+ I: Iterator<Item = &'a String>,
+ {
+ for key in keys {
+ let mut current = &mut self.root;
+ for c in key.chars() {
+ current = current.children.entry(c).or_insert(TrieNode {
+ is_word: false,
+ children: HashMap::new(),
+ });
+ }
+ current.is_word = true;
+ }
+ }
+}
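A test sketch for the trie that could live at the bottom of trie.rs (not part of the patch; the keys are illustrative), showing the three TrieResult outcomes that scan_keyword relies on:

    #[cfg(test)]
    mod tests {
        use super::{Trie, TrieResult};

        #[test]
        fn exists_prefix_and_failed() {
            let mut trie = Trie::new();
            let keys = vec!["GROUP BY".to_string(), "GROUP BY ALL".to_string()];
            trie.add(keys.iter());

            // A full key reports Exists, a proper prefix reports Prefix, and a
            // miss (or an empty key) reports Failed.
            assert!(matches!(trie.root.contains("GROUP BY"), (TrieResult::Exists, _)));
            assert!(matches!(trie.root.contains("GROUP"), (TrieResult::Prefix, _)));
            assert!(matches!(trie.root.contains("ORDER"), (TrieResult::Failed, _)));
            assert!(matches!(trie.root.contains(""), (TrieResult::Failed, _)));
        }
    }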