From f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 19 Dec 2023 12:01:55 +0100 Subject: Merging upstream version 20.3.0. Signed-off-by: Daniel Baumann --- sqlglotrs/Cargo.lock | 280 +++++++++++++++++++ sqlglotrs/Cargo.toml | 11 + sqlglotrs/pyproject.toml | 16 ++ sqlglotrs/src/lib.rs | 86 ++++++ sqlglotrs/src/settings.rs | 164 +++++++++++ sqlglotrs/src/tokenizer.rs | 670 +++++++++++++++++++++++++++++++++++++++++++++ sqlglotrs/src/trie.rs | 68 +++++ 7 files changed, 1295 insertions(+) create mode 100644 sqlglotrs/Cargo.lock create mode 100644 sqlglotrs/Cargo.toml create mode 100644 sqlglotrs/pyproject.toml create mode 100644 sqlglotrs/src/lib.rs create mode 100644 sqlglotrs/src/settings.rs create mode 100644 sqlglotrs/src/tokenizer.rs create mode 100644 sqlglotrs/src/trie.rs (limited to 'sqlglotrs') diff --git a/sqlglotrs/Cargo.lock b/sqlglotrs/Cargo.lock new file mode 100644 index 0000000..cd9a9ef --- /dev/null +++ b/sqlglotrs/Cargo.lock @@ -0,0 +1,280 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "indoc" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" + +[[package]] +name = "libc" +version = "0.2.150" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" + +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "proc-macro2" +version = "1.0.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" + +[[package]] +name = "sqlglotrs" +version = "0.1.0" +dependencies = [ + "pyo3", +] + +[[package]] +name = "syn" +version = "2.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c39fd04924ca3a864207c66fc2cd7d22d7c016007f9ce846cbb9326331930a" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/sqlglotrs/Cargo.toml b/sqlglotrs/Cargo.toml new file mode 100644 index 0000000..ece4a88 --- /dev/null +++ b/sqlglotrs/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "sqlglotrs" +version = "0.1.0" +edition = "2021" + +[lib] +name = "sqlglotrs" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = "0.20.0" diff --git a/sqlglotrs/pyproject.toml b/sqlglotrs/pyproject.toml new file mode 100644 index 0000000..867cdcc --- /dev/null +++ b/sqlglotrs/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["maturin>=1.4,<2.0"] +build-backend = "maturin" + +[project] +name = "sqlglotrs" +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] + +[tool.maturin] +features = ["pyo3/extension-module"] diff --git a/sqlglotrs/src/lib.rs b/sqlglotrs/src/lib.rs new file mode 100644 index 0000000..c962887 --- /dev/null +++ b/sqlglotrs/src/lib.rs @@ -0,0 +1,86 @@ +use pyo3::prelude::*; +use pyo3::types::{PyList, PyNone, PyString}; + +mod settings; +mod tokenizer; +mod trie; + +pub use self::settings::{ + TokenType, TokenTypeSettings, TokenizerDialectSettings, TokenizerSettings, +}; +pub use self::tokenizer::Tokenizer; + +#[derive(Debug)] +#[pyclass] +pub struct Token { + #[pyo3(get, name = "token_type_index")] + pub token_type: TokenType, + #[pyo3(get, set, name = "token_type")] + pub token_type_py: PyObject, + #[pyo3(get)] + pub text: Py, + #[pyo3(get)] + pub line: usize, + #[pyo3(get)] + pub col: usize, + #[pyo3(get)] + pub start: usize, + #[pyo3(get)] + pub end: usize, + #[pyo3(get)] + pub comments: Py, +} + +impl Token { + pub fn new( + token_type: TokenType, + text: String, + line: usize, + col: usize, + start: usize, + end: usize, + comments: Vec, + ) -> Token { + Python::with_gil(|py| Token { + token_type, + token_type_py: PyNone::get(py).into(), + text: PyString::new(py, &text).into(), + line, + col, + start, + end, + comments: PyList::new(py, &comments).into(), + }) + } + + pub fn append_comments(&self, comments: &mut Vec) { + Python::with_gil(|py| { + let pylist = self.comments.as_ref(py); + for comment in comments.iter() { + if let Err(_) = pylist.append(comment) { + panic!("Failed to append comments to the Python list"); + } + } + }); + // Simulate `Vec::append`. + let _ = std::mem::replace(comments, Vec::new()); + } +} + +#[pymethods] +impl Token { + #[pyo3(name = "__repr__")] + fn python_repr(&self) -> PyResult { + Ok(format!("{:?}", self)) + } +} + +#[pymodule] +fn sqlglotrs(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/sqlglotrs/src/settings.rs b/sqlglotrs/src/settings.rs new file mode 100644 index 0000000..32575c6 --- /dev/null +++ b/sqlglotrs/src/settings.rs @@ -0,0 +1,164 @@ +use pyo3::prelude::*; +use std::collections::{HashMap, HashSet}; + +pub type TokenType = u16; + +#[derive(Clone, Debug)] +#[pyclass] +pub struct TokenTypeSettings { + pub bit_string: TokenType, + pub break_: TokenType, + pub dcolon: TokenType, + pub heredoc_string: TokenType, + pub hex_string: TokenType, + pub identifier: TokenType, + pub number: TokenType, + pub parameter: TokenType, + pub semicolon: TokenType, + pub string: TokenType, + pub var: TokenType, +} + +#[pymethods] +impl TokenTypeSettings { + #[new] + pub fn new( + bit_string: TokenType, + break_: TokenType, + dcolon: TokenType, + heredoc_string: TokenType, + hex_string: TokenType, + identifier: TokenType, + number: TokenType, + parameter: TokenType, + semicolon: TokenType, + string: TokenType, + var: TokenType, + ) -> Self { + TokenTypeSettings { + bit_string, + break_, + dcolon, + heredoc_string, + hex_string, + identifier, + number, + parameter, + semicolon, + string, + var, + } + } +} + +#[derive(Clone, Debug)] +#[pyclass] +pub struct TokenizerSettings { + pub white_space: HashMap, + pub single_tokens: HashMap, + pub keywords: HashMap, + pub numeric_literals: HashMap, + pub identifiers: HashMap, + pub identifier_escapes: HashSet, + pub string_escapes: HashSet, + pub quotes: HashMap, + pub format_strings: HashMap, + pub has_bit_strings: bool, + pub has_hex_strings: bool, + pub comments: HashMap>, + pub var_single_tokens: HashSet, + pub commands: HashSet, + pub command_prefix_tokens: HashSet, +} + +#[pymethods] +impl TokenizerSettings { + #[new] + pub fn new( + white_space: HashMap, + single_tokens: HashMap, + keywords: HashMap, + numeric_literals: HashMap, + identifiers: HashMap, + identifier_escapes: HashSet, + string_escapes: HashSet, + quotes: HashMap, + format_strings: HashMap, + has_bit_strings: bool, + has_hex_strings: bool, + comments: HashMap>, + var_single_tokens: HashSet, + commands: HashSet, + command_prefix_tokens: HashSet, + ) -> Self { + let to_char = |v: &String| { + if v.len() == 1 { + v.chars().next().unwrap() + } else { + panic!("Invalid char: {}", v) + } + }; + + let white_space_native: HashMap = white_space + .into_iter() + .map(|(k, v)| (to_char(&k), v)) + .collect(); + + let single_tokens_native: HashMap = single_tokens + .into_iter() + .map(|(k, v)| (to_char(&k), v)) + .collect(); + + let identifiers_native: HashMap = identifiers + .iter() + .map(|(k, v)| (to_char(k), to_char(v))) + .collect(); + + let identifier_escapes_native: HashSet = + identifier_escapes.iter().map(&to_char).collect(); + + let string_escapes_native: HashSet = string_escapes.iter().map(&to_char).collect(); + + let var_single_tokens_native: HashSet = + var_single_tokens.iter().map(&to_char).collect(); + + TokenizerSettings { + white_space: white_space_native, + single_tokens: single_tokens_native, + keywords, + numeric_literals, + identifiers: identifiers_native, + identifier_escapes: identifier_escapes_native, + string_escapes: string_escapes_native, + quotes, + format_strings, + has_bit_strings, + has_hex_strings, + comments, + var_single_tokens: var_single_tokens_native, + commands, + command_prefix_tokens, + } + } +} + +#[derive(Clone, Debug)] +#[pyclass] +pub struct TokenizerDialectSettings { + pub escape_sequences: HashMap, + pub identifiers_can_start_with_digit: bool, +} + +#[pymethods] +impl TokenizerDialectSettings { + #[new] + pub fn new( + escape_sequences: HashMap, + identifiers_can_start_with_digit: bool, + ) -> Self { + TokenizerDialectSettings { + escape_sequences, + identifiers_can_start_with_digit, + } + } +} diff --git a/sqlglotrs/src/tokenizer.rs b/sqlglotrs/src/tokenizer.rs new file mode 100644 index 0000000..920a5b5 --- /dev/null +++ b/sqlglotrs/src/tokenizer.rs @@ -0,0 +1,670 @@ +use crate::trie::{Trie, TrieResult}; +use crate::{Token, TokenType, TokenTypeSettings, TokenizerDialectSettings, TokenizerSettings}; +use pyo3::exceptions::PyException; +use pyo3::prelude::*; +use std::cmp::{max, min}; + +#[derive(Debug)] +pub struct TokenizerError { + message: String, + context: String, +} + +#[derive(Debug)] +#[pyclass] +pub struct Tokenizer { + settings: TokenizerSettings, + token_types: TokenTypeSettings, + keyword_trie: Trie, +} + +#[pymethods] +impl Tokenizer { + #[new] + pub fn new(settings: TokenizerSettings, token_types: TokenTypeSettings) -> Tokenizer { + let mut keyword_trie = Trie::new(); + let single_token_strs: Vec = settings + .single_tokens + .keys() + .map(|s| s.to_string()) + .collect(); + let trie_filter = + |key: &&String| key.contains(" ") || single_token_strs.iter().any(|t| key.contains(t)); + + keyword_trie.add(settings.keywords.keys().filter(trie_filter)); + keyword_trie.add(settings.comments.keys().filter(trie_filter)); + keyword_trie.add(settings.quotes.keys().filter(trie_filter)); + keyword_trie.add(settings.format_strings.keys().filter(trie_filter)); + + Tokenizer { + settings, + token_types, + keyword_trie, + } + } + + pub fn tokenize( + &self, + sql: &str, + dialect_settings: &TokenizerDialectSettings, + ) -> Result, PyErr> { + let mut state = TokenizerState::new( + sql, + &self.settings, + &self.token_types, + dialect_settings, + &self.keyword_trie, + ); + state.tokenize().map_err(|e| { + PyException::new_err(format!("Error tokenizing '{}': {}", e.context, e.message)) + }) + } +} + +#[derive(Debug)] +struct TokenizerState<'a> { + sql: Vec, + size: usize, + tokens: Vec, + start: usize, + current: usize, + line: usize, + column: usize, + comments: Vec, + is_end: bool, + current_char: char, + peek_char: char, + previous_token_line: Option, + keyword_trie: &'a Trie, + settings: &'a TokenizerSettings, + dialect_settings: &'a TokenizerDialectSettings, + token_types: &'a TokenTypeSettings, +} + +impl<'a> TokenizerState<'a> { + fn new( + sql: &str, + settings: &'a TokenizerSettings, + token_types: &'a TokenTypeSettings, + dialect_settings: &'a TokenizerDialectSettings, + keyword_trie: &'a Trie, + ) -> TokenizerState<'a> { + let sql_vec = sql.chars().collect::>(); + let sql_vec_len = sql_vec.len(); + TokenizerState { + sql: sql_vec, + size: sql_vec_len, + tokens: Vec::new(), + start: 0, + current: 0, + line: 1, + column: 0, + comments: Vec::new(), + is_end: false, + current_char: '\0', + peek_char: '\0', + previous_token_line: None, + keyword_trie, + settings, + dialect_settings, + token_types, + } + } + + fn tokenize(&mut self) -> Result, TokenizerError> { + self.scan(None)?; + Ok(std::mem::replace(&mut self.tokens, Vec::new())) + } + + fn scan(&mut self, until_peek_char: Option) -> Result<(), TokenizerError> { + while self.size > 0 && !self.is_end { + self.start = self.current; + self.advance(1)?; + + if self.current_char == '\0' { + break; + } + + if !self.settings.white_space.contains_key(&self.current_char) { + if self.current_char.is_digit(10) { + self.scan_number()?; + } else if let Some(identifier_end) = + self.settings.identifiers.get(&self.current_char) + { + self.scan_identifier(&identifier_end.to_string())?; + } else { + self.scan_keyword()?; + } + } + + if let Some(c) = until_peek_char { + if self.peek_char == c { + break; + } + } + } + if !self.tokens.is_empty() && !self.comments.is_empty() { + self.tokens + .last_mut() + .unwrap() + .append_comments(&mut self.comments); + } + Ok(()) + } + + fn advance(&mut self, i: isize) -> Result<(), TokenizerError> { + let mut i = i; + if Some(&self.token_types.break_) == self.settings.white_space.get(&self.current_char) { + // Ensures we don't count an extra line if we get a \r\n line break sequence. + if self.current_char == '\r' && self.peek_char == '\n' { + i = 2; + self.start += 1; + } + + self.column = 1; + self.line += 1; + } else { + self.column = self.column.wrapping_add_signed(i); + } + + self.current = self.current.wrapping_add_signed(i); + self.is_end = self.current >= self.size; + self.current_char = self.char_at(self.current - 1)?; + self.peek_char = if self.is_end { + '\0' + } else { + self.char_at(self.current)? + }; + Ok(()) + } + + fn peek(&self, i: usize) -> Result { + let index = self.current + i; + if index < self.size { + self.char_at(index) + } else { + Ok('\0') + } + } + + fn chars(&self, size: usize) -> String { + let start = self.current - 1; + let end = start + size; + if end <= self.size { + self.sql[start..end].iter().collect() + } else { + String::from("") + } + } + + fn char_at(&self, index: usize) -> Result { + self.sql.get(index).map(|c| *c).ok_or_else(|| { + self.error(format!( + "Index {} is out of bound (size {})", + index, self.size + )) + }) + } + + fn text(&self) -> String { + self.sql[self.start..self.current].iter().collect() + } + + fn add(&mut self, token_type: TokenType, text: Option) -> Result<(), TokenizerError> { + self.previous_token_line = Some(self.line); + + if !self.comments.is_empty() + && !self.tokens.is_empty() + && token_type == self.token_types.semicolon + { + self.tokens + .last_mut() + .unwrap() + .append_comments(&mut self.comments); + } + + self.tokens.push(Token::new( + token_type, + text.unwrap_or(self.text()), + self.line, + self.column, + self.start, + self.current - 1, + std::mem::replace(&mut self.comments, Vec::new()), + )); + + // If we have either a semicolon or a begin token before the command's token, we'll parse + // whatever follows the command's token as a string. + if self.settings.commands.contains(&token_type) + && self.peek_char != ';' + && (self.tokens.len() == 1 + || self + .settings + .command_prefix_tokens + .contains(&self.tokens[self.tokens.len() - 2].token_type)) + { + let start = self.current; + let tokens_len = self.tokens.len(); + self.scan(Some(';'))?; + self.tokens.truncate(tokens_len); + let text = self.sql[start..self.current] + .iter() + .collect::() + .trim() + .to_string(); + if !text.is_empty() { + self.add(self.token_types.string, Some(text))?; + } + } + Ok(()) + } + + fn scan_keyword(&mut self) -> Result<(), TokenizerError> { + let mut size: usize = 0; + let mut word: Option = None; + let mut chars = self.text(); + let mut current_char = '\0'; + let mut prev_space = false; + let mut skip; + let mut is_single_token = chars.len() == 1 + && self + .settings + .single_tokens + .contains_key(&chars.chars().next().unwrap()); + + let (mut trie_result, mut trie_node) = + self.keyword_trie.root.contains(&chars.to_uppercase()); + + while !chars.is_empty() { + if let TrieResult::Failed = trie_result { + break; + } else if let TrieResult::Exists = trie_result { + word = Some(chars.clone()); + } + + let end = self.current + size; + size += 1; + + if end < self.size { + current_char = self.char_at(end)?; + is_single_token = + is_single_token || self.settings.single_tokens.contains_key(¤t_char); + let is_space = current_char.is_whitespace(); + + if !is_space || !prev_space { + if is_space { + current_char = ' '; + } + chars.push(current_char); + prev_space = is_space; + skip = false; + } else { + skip = true; + } + } else { + current_char = '\0'; + break; + } + + if skip { + trie_result = TrieResult::Prefix; + } else { + (trie_result, trie_node) = + trie_node.contains(¤t_char.to_uppercase().collect::()); + } + } + + if let Some(unwrapped_word) = word { + if self.scan_string(&unwrapped_word)? { + return Ok(()); + } + if self.scan_comment(&unwrapped_word)? { + return Ok(()); + } + if prev_space || is_single_token || current_char == '\0' { + self.advance((size - 1) as isize)?; + let normalized_word = unwrapped_word.to_uppercase(); + let keyword_token = + *self + .settings + .keywords + .get(&normalized_word) + .ok_or_else(|| { + self.error(format!("Unexpected keyword '{}'", &normalized_word)) + })?; + self.add(keyword_token, Some(unwrapped_word))?; + return Ok(()); + } + } + + match self.settings.single_tokens.get(&self.current_char) { + Some(token_type) => self.add(*token_type, Some(self.current_char.to_string())), + None => self.scan_var(), + } + } + + fn scan_comment(&mut self, comment_start: &str) -> Result { + if !self.settings.comments.contains_key(comment_start) { + return Ok(false); + } + + let comment_start_line = self.line; + let comment_start_size = comment_start.len(); + + if let Some(comment_end) = self.settings.comments.get(comment_start).unwrap() { + // Skip the comment's start delimiter. + self.advance(comment_start_size as isize)?; + + let comment_end_size = comment_end.len(); + + while !self.is_end && self.chars(comment_end_size) != *comment_end { + self.advance(1)?; + } + + let text = self.text(); + self.comments + .push(text[comment_start_size..text.len() - comment_end_size + 1].to_string()); + self.advance((comment_end_size - 1) as isize)?; + } else { + while !self.is_end + && self.settings.white_space.get(&self.peek_char) != Some(&self.token_types.break_) + { + self.advance(1)?; + } + self.comments + .push(self.text()[comment_start_size..].to_string()); + } + + // Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. + // Multiple consecutive comments are preserved by appending them to the current comments list. + if Some(comment_start_line) == self.previous_token_line { + self.tokens + .last_mut() + .unwrap() + .append_comments(&mut self.comments); + self.previous_token_line = Some(self.line); + } + + Ok(true) + } + + fn scan_string(&mut self, start: &String) -> Result { + let (base, token_type, end) = if let Some(end) = self.settings.quotes.get(start) { + (None, self.token_types.string, end.clone()) + } else if self.settings.format_strings.contains_key(start) { + let (ref end, token_type) = self.settings.format_strings.get(start).unwrap(); + + if *token_type == self.token_types.hex_string { + (Some(16), *token_type, end.clone()) + } else if *token_type == self.token_types.bit_string { + (Some(2), *token_type, end.clone()) + } else if *token_type == self.token_types.heredoc_string { + self.advance(1)?; + let tag = if self.current_char.to_string() == *end { + String::from("") + } else { + self.extract_string(end, false)? + }; + (None, *token_type, format!("{}{}{}", start, tag, end)) + } else { + (None, *token_type, end.clone()) + } + } else { + return Ok(false); + }; + + self.advance(start.len() as isize)?; + let text = self.extract_string(&end, false)?; + + if let Some(b) = base { + if u64::from_str_radix(&text, b).is_err() { + return self.error_result(format!( + "Numeric string contains invalid characters from {}:{}", + self.line, self.start + )); + } + } + + self.add(token_type, Some(text))?; + Ok(true) + } + + fn scan_number(&mut self) -> Result<(), TokenizerError> { + if self.current_char == '0' { + let peek_char = self.peek_char.to_ascii_uppercase(); + if peek_char == 'B' { + if self.settings.has_bit_strings { + self.scan_bits()?; + } else { + self.add(self.token_types.number, None)?; + } + return Ok(()); + } else if peek_char == 'X' { + if self.settings.has_hex_strings { + self.scan_hex()?; + } else { + self.add(self.token_types.number, None)?; + } + return Ok(()); + } + } + + let mut decimal = false; + let mut scientific = 0; + + loop { + if self.peek_char.is_digit(10) { + self.advance(1)?; + } else if self.peek_char == '.' && !decimal { + let after = self.peek(1)?; + if after.is_digit(10) || !after.is_alphabetic() { + decimal = true; + self.advance(1)?; + } else { + return self.add(self.token_types.var, None); + } + } else if (self.peek_char == '-' || self.peek_char == '+') && scientific == 1 { + scientific += 1; + self.advance(1)?; + } else if self.peek_char.to_ascii_uppercase() == 'E' && scientific == 0 { + scientific += 1; + self.advance(1)?; + } else if self.peek_char.is_alphabetic() || self.peek_char == '_' { + let number_text = self.text(); + let mut literal = String::from(""); + + while !self.peek_char.is_whitespace() + && !self.is_end + && !self.settings.single_tokens.contains_key(&self.peek_char) + { + literal.push(self.peek_char); + self.advance(1)?; + } + + let token_type = self + .settings + .keywords + .get( + self.settings + .numeric_literals + .get(&literal.to_uppercase()) + .unwrap_or(&String::from("")), + ) + .map(|x| *x); + + if let Some(unwrapped_token_type) = token_type { + self.add(self.token_types.number, Some(number_text))?; + self.add(self.token_types.dcolon, Some("::".to_string()))?; + self.add(unwrapped_token_type, Some(literal))?; + } else if self.dialect_settings.identifiers_can_start_with_digit { + self.add(self.token_types.var, None)?; + } else { + self.advance(-(literal.chars().count() as isize))?; + self.add(self.token_types.number, Some(number_text))?; + } + return Ok(()); + } else { + return self.add(self.token_types.number, None); + } + } + } + + fn scan_bits(&mut self) -> Result<(), TokenizerError> { + self.scan_radix_string(2, self.token_types.bit_string) + } + + fn scan_hex(&mut self) -> Result<(), TokenizerError> { + self.scan_radix_string(16, self.token_types.hex_string) + } + + fn scan_radix_string( + &mut self, + radix: u32, + radix_token_type: TokenType, + ) -> Result<(), TokenizerError> { + self.advance(1)?; + let value = self.extract_value()?[2..].to_string(); + match u64::from_str_radix(&value, radix) { + Ok(_) => self.add(radix_token_type, Some(value)), + Err(_) => self.add(self.token_types.identifier, None), + } + } + + fn scan_var(&mut self) -> Result<(), TokenizerError> { + loop { + let peek_char = if !self.peek_char.is_whitespace() { + self.peek_char + } else { + '\0' + }; + if peek_char != '\0' + && (self.settings.var_single_tokens.contains(&peek_char) + || !self.settings.single_tokens.contains_key(&peek_char)) + { + self.advance(1)?; + } else { + break; + } + } + + let token_type = + if self.tokens.last().map(|t| t.token_type) == Some(self.token_types.parameter) { + self.token_types.var + } else { + self.settings + .keywords + .get(&self.text().to_uppercase()) + .map(|x| *x) + .unwrap_or(self.token_types.var) + }; + self.add(token_type, None) + } + + fn scan_identifier(&mut self, identifier_end: &str) -> Result<(), TokenizerError> { + self.advance(1)?; + let text = self.extract_string(identifier_end, true)?; + self.add(self.token_types.identifier, Some(text)) + } + + fn extract_string( + &mut self, + delimiter: &str, + use_identifier_escapes: bool, + ) -> Result { + let mut text = String::from(""); + + loop { + let escapes = if use_identifier_escapes { + &self.settings.identifier_escapes + } else { + &self.settings.string_escapes + }; + + let peek_char_str = self.peek_char.to_string(); + if escapes.contains(&self.current_char) + && (peek_char_str == delimiter || escapes.contains(&self.peek_char)) + && (self.current_char == self.peek_char + || !self + .settings + .quotes + .contains_key(&self.current_char.to_string())) + { + if peek_char_str == delimiter { + text.push(self.peek_char); + } else { + text.push(self.current_char); + text.push(self.peek_char); + } + if self.current + 1 < self.size { + self.advance(2)?; + } else { + return self.error_result(format!( + "Missing {} from {}:{}", + delimiter, self.line, self.current + )); + } + } else { + if self.chars(delimiter.len()) == delimiter { + if delimiter.len() > 1 { + self.advance((delimiter.len() - 1) as isize)?; + } + break; + } + if self.is_end { + return self.error_result(format!( + "Missing {} from {}:{}", + delimiter, self.line, self.current + )); + } + + if !self.dialect_settings.escape_sequences.is_empty() + && !self.peek_char.is_whitespace() + && self.settings.string_escapes.contains(&self.current_char) + { + let sequence_key = format!("{}{}", self.current_char, self.peek_char); + if let Some(escaped_sequence) = + self.dialect_settings.escape_sequences.get(&sequence_key) + { + self.advance(2)?; + text.push_str(escaped_sequence); + continue; + } + } + + let current = self.current - 1; + self.advance(1)?; + text.push_str( + &self.sql[current..self.current - 1] + .iter() + .collect::(), + ); + } + } + Ok(text) + } + + fn extract_value(&mut self) -> Result { + loop { + if !self.peek_char.is_whitespace() + && !self.is_end + && !self.settings.single_tokens.contains_key(&self.peek_char) + { + self.advance(1)?; + } else { + break; + } + } + Ok(self.text()) + } + + fn error(&self, message: String) -> TokenizerError { + let start = max((self.current as isize) - 50, 0); + let end = min(self.current + 50, self.size - 1); + let context = self.sql[start as usize..end].iter().collect::(); + TokenizerError { message, context } + } + + fn error_result(&self, message: String) -> Result { + Err(self.error(message)) + } +} diff --git a/sqlglotrs/src/trie.rs b/sqlglotrs/src/trie.rs new file mode 100644 index 0000000..8e6f20c --- /dev/null +++ b/sqlglotrs/src/trie.rs @@ -0,0 +1,68 @@ +use std::collections::HashMap; + +#[derive(Debug)] +pub struct TrieNode { + is_word: bool, + children: HashMap, +} + +#[derive(Debug)] +pub enum TrieResult { + Failed, + Prefix, + Exists, +} + +impl TrieNode { + pub fn contains(&self, key: &str) -> (TrieResult, &TrieNode) { + if key.is_empty() { + return (TrieResult::Failed, self); + } + + let mut current = self; + for c in key.chars() { + match current.children.get(&c) { + Some(node) => current = node, + None => return (TrieResult::Failed, current), + } + } + + if current.is_word { + (TrieResult::Exists, current) + } else { + (TrieResult::Prefix, current) + } + } +} + +#[derive(Debug)] +pub struct Trie { + pub root: TrieNode, +} + +impl Trie { + pub fn new() -> Self { + Trie { + root: TrieNode { + is_word: false, + children: HashMap::new(), + }, + } + } + + pub fn add<'a, I>(&mut self, keys: I) + where + I: Iterator, + { + for key in keys { + let mut current = &mut self.root; + for c in key.chars() { + current = current.children.entry(c).or_insert(TrieNode { + is_word: false, + children: HashMap::new(), + }); + } + current.is_word = true; + } + } +} -- cgit v1.2.3