diff options
Diffstat (limited to 'sqlglot/dialects/sqlite.py')
-rw-r--r-- | sqlglot/dialects/sqlite.py | 43 |
1 files changed, 40 insertions, 3 deletions
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 4437f82..f2efe32 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -22,6 +22,40 @@ def _date_add_sql(self, expression): return self.func("DATE", expression.this, modifier) +def _transform_create(expression: exp.Expression) -> exp.Expression: + """Move primary key to a column and enforce auto_increment on primary keys.""" + schema = expression.this + + if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema): + defs = {} + primary_key = None + + for e in schema.expressions: + if isinstance(e, exp.ColumnDef): + defs[e.name] = e + elif isinstance(e, exp.PrimaryKey): + primary_key = e + + if primary_key and len(primary_key.expressions) == 1: + column = defs[primary_key.expressions[0].name] + column.append( + "constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint()) + ) + schema.expressions.remove(primary_key) + else: + for column in defs.values(): + auto_increment = None + for constraint in column.constraints.copy(): + if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint): + break + if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint): + auto_increment = constraint + if auto_increment: + column.constraints.remove(auto_increment) + + return expression + + class SQLite(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -65,8 +99,8 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore - **transforms.ELIMINATE_QUALIFY, # type: ignore exp.CountIf: count_if_to_sum, + exp.Create: transforms.preprocess([_transform_create]), exp.CurrentDate: lambda *_: "CURRENT_DATE", exp.CurrentTime: lambda *_: "CURRENT_TIME", exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", @@ -80,14 +114,17 @@ class SQLite(Dialect): exp.Levenshtein: rename_func("EDITDIST3"), exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: rename_func("MIN"), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_qualify] + ), exp.TableSample: no_tablesample_sql, exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), exp.TryCast: no_trycast_sql, } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + k: exp.Properties.Location.UNSUPPORTED + for k, v in generator.Generator.PROPERTIES_LOCATION.items() } LIMIT_FETCH = "LIMIT" |