diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_tables.py')
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 0e467d3..5d8e0d9 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -1,10 +1,11 @@ import itertools from sqlglot import alias, exp +from sqlglot.helper import csv_reader from sqlglot.optimizer.scope import traverse_scope -def qualify_tables(expression, db=None, catalog=None): +def qualify_tables(expression, db=None, catalog=None, schema=None): """ Rewrite sqlglot AST to have fully qualified tables. @@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None): expression (sqlglot.Expression): expression to qualify db (str): Database name catalog (str): Catalog name + schema: A schema to populate Returns: sqlglot.Expression: qualified expression """ @@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None): source.set("catalog", exp.to_identifier(catalog)) if not source.alias: - source.replace( + source = source.replace( alias( source.copy(), source.this if identifier else f"_q_{next(sequence)}", @@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None): ) ) + if schema and isinstance(source.this, exp.ReadCSV): + with csv_reader(source.this) as reader: + header = next(reader) + columns = next(reader) + schema.add_table( + source, {k: type(v).__name__ for k, v in zip(header, columns)} + ) + return expression |