summaryrefslogtreecommitdiffstats
path: root/psycopg/psycopg/sql.py
blob: 099a01cc429c7c1b9262f0daa864397dcc28043e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
"""
SQL composition utility module
"""

# Copyright (C) 2020 The Psycopg Team

import codecs
import string
from abc import ABC, abstractmethod
from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union

from .pq import Escaping
from .abc import AdaptContext
from .adapt import Transformer, PyFormat
from ._compat import LiteralString
from ._encodings import conn_encoding


def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
    """
    Convert a Python object into its quoted SQL literal representation.

    The function is meant for the uncommon case in which a literal must be
    rendered ahead of time (for instance to generate a batch of SQL
    statements) and no connection will be available when the result is
    needed.

    The function is relatively inefficient, because the adaptation rules are
    not cached between calls. Passing a `!context` allows to customise the
    adaptation rules used; without one, only the global rules apply.

    :param obj: the Python object to convert.
    :param context: the optional context providing the adaptation rules.
    :return: the object rendered by `Literal.as_string()`.
    """
    literal = Literal(obj)
    return literal.as_string(context)


class Composable(ABC):
    """
    Base class of every object that can be used as a piece of an SQL string.

    A `!Composable` can be passed to `~psycopg.Cursor.execute()`,
    `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` wherever a query
    string is expected.

    Two operators are available to combine `!Composable` objects:

    - ``+`` joins two objects and returns a `Composed` containing both;
    - ``*`` takes an integer on the right and returns a `!Composed` containing
      the left operand repeated that many times.
    """

    def __init__(self, obj: Any):
        # The wrapped value: each subclass decides how to interpret it.
        self._obj = obj

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}({self._obj!r})"

    @abstractmethod
    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        """
        Return the value of the object as bytes.

        :param context: the context to evaluate the object into.
        :type context: `connection` or `cursor`

        This is the method invoked automatically by
        `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
        `~psycopg.Cursor.copy()` when they receive a `!Composable` in place of
        the query string.

        """
        raise NotImplementedError

    def as_string(self, context: Optional[AdaptContext]) -> str:
        """
        Return the value of the object as string.

        :param context: the context to evaluate the string into.
        :type context: `connection` or `cursor`

        The result is the output of `as_bytes()` decoded using the encoding
        returned by `!conn_encoding()` for the context's connection (or for
        `!None` if no context is given).

        """
        conn = None
        if context:
            conn = context.connection
        encoding = conn_encoding(conn)
        raw = self.as_bytes(context)
        if not isinstance(raw, bytes):
            # Not a real `bytes` (a buffer object): it has no `decode()`
            # method to rely on, so go through the codec machinery.
            return codecs.lookup(encoding).decode(raw)[0]
        return raw.decode(encoding)

    def __add__(self, other: "Composable") -> "Composed":
        # Let Python try the reflected operation on unrelated types.
        if not isinstance(other, Composable):
            return NotImplemented

        left = Composed([self])
        if isinstance(other, Composed):
            return left + other
        return left + Composed([other])

    def __mul__(self, n: int) -> "Composed":
        repeated = [self] * n
        return Composed(repeated)

    def __eq__(self, other: Any) -> bool:
        # Objects are equal only if of the exact same class, wrapping equal
        # values.
        if type(self) is not type(other):
            return False
        return self._obj == other._obj

    def __ne__(self, other: Any) -> bool:
        equal = self.__eq__(other)
        return not equal


class Composed(Composable):
    """
    A `Composable` object made of a sequence of `!Composable`.

    Instances are normally the result of `!Composable` operators and methods,
    but a `!Composed` can also be built directly from a sequence of objects:
    any item which is not a `!Composable` is wrapped in a `Literal`.

    Example::

        >>> comp = sql.Composed(
        ...     [sql.SQL("INSERT INTO "), sql.Identifier("table")])
        >>> print(comp.as_string(conn))
        INSERT INTO "table"

    `!Composed` objects are iterable, which makes them usable, for instance,
    as argument of `SQL.join`.
    """

    _obj: List[Composable]

    def __init__(self, seq: Sequence[Any]):
        items: List[Composable] = []
        for item in seq:
            # Anything which is not composable is treated as a value.
            if not isinstance(item, Composable):
                item = Literal(item)
            items.append(item)
        super().__init__(items)

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # Render every item with the same context and concatenate the results.
        chunks = [part.as_bytes(context) for part in self._obj]
        return b"".join(chunks)

    def __iter__(self) -> Iterator[Composable]:
        return iter(self._obj)

    def __add__(self, other: Composable) -> "Composed":
        if isinstance(other, Composed):
            # Merge the items of the two sequences, without nesting them.
            tail = other._obj
        elif isinstance(other, Composable):
            tail = [other]
        else:
            return NotImplemented
        return Composed(self._obj + tail)

    def join(self, joiner: Union["SQL", LiteralString]) -> "Composed":
        """
        Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.

        :param joiner: the separator: either an `SQL` object or a string,
            which will be converted to an `SQL`.
        :raise TypeError: if `!joiner` is neither a string nor an `SQL`.

        Example::

            >>> fields = sql.Identifier('foo') + sql.Identifier('bar')  # a Composed
            >>> print(fields.join(', ').as_string(conn))
            "foo", "bar"

        """
        if isinstance(joiner, str):
            separator = SQL(joiner)
        elif isinstance(joiner, SQL):
            separator = joiner
        else:
            raise TypeError(
                "Composed.join() argument must be strings or SQL,"
                f" got {joiner!r} instead"
            )

        return separator.join(self._obj)


class SQL(Composable):
    """
    A `Composable` wrapping a snippet of SQL statement.

    `!SQL` provides the `join()` and `format()` methods, useful to build a
    template in which to merge the variable parts of a query (for instance
    field or table names).

    The wrapped string is emitted verbatim, with no form of escaping, so it
    is not suitable to represent variable identifiers or values: use it only
    for constant strings representing templates or snippets of SQL statements,
    and use other objects such as `Identifier` or `Literal` to represent the
    variable parts.

    Example::

        >>> query = sql.SQL("SELECT {0} FROM {1}").format(
        ...    sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
        ...    sql.Identifier('table'))
        >>> print(query.as_string(conn))
        SELECT "foo", "bar" FROM "table"
    """

    _obj: LiteralString
    # Shared parser used by `format()` to split the template into fields.
    _formatter = string.Formatter()

    def __init__(self, obj: LiteralString):
        # The value is stored first and validated afterwards.
        super().__init__(obj)
        if isinstance(obj, str):
            return
        raise TypeError(f"SQL values must be strings, got {obj!r} instead")

    def as_string(self, context: Optional[AdaptContext]) -> str:
        # The snippet is already a string: the context is not needed.
        return self._obj

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        if context:
            encoding = conn_encoding(context.connection)
        else:
            encoding = "utf-8"
        return self._obj.encode(encoding)

    def format(self, *args: Any, **kwargs: Any) -> Composed:
        """
        Merge `Composable` objects into a template.

        :param args: parameters to replace to numbered (``{0}``, ``{1}``) or
            auto-numbered (``{}``) placeholders
        :param kwargs: parameters to replace to named (``{name}``) placeholders
        :return: the union of the `!SQL` string with placeholders replaced
        :rtype: `Composed`
        :raise ValueError: if a placeholder has a format specification or a
            conversion, or if the template mixes automatic and manual
            numbering.

        The method works similarly to Python's `str.format()`: the template
        may contain auto-numbered (``{}``), numbered (``{0}``, ``{1}``...),
        and named (``{name}``) placeholders; positional arguments replace the
        numbered ones and keyword arguments replace the named ones. However
        placeholder modifiers (``{0!r}``, ``{0:<10}``) are not supported.

        If a `!Composable` object is passed to the template it will be merged
        according to its `as_string()` method. Any other Python object will
        be wrapped in a `Literal` object and so escaped according to SQL
        rules.

        Example::

            >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
            ...     .format(sql.Identifier('people'), sql.Identifier('id'))
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE "id" = %s

            >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
            ...     .format(tbl=sql.Identifier('people'), name="O'Rourke")
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE name = 'O''Rourke'

        """
        pieces: List[Composable] = []
        # Index of the next auto-numbered field; None once a manually
        # numbered field has been found.
        next_auto: Optional[int] = 0
        # TODO: this is probably not the right way to whitelist pre
        # pyre complains. Will wait for mypy to complain too to fix.
        pre: LiteralString
        for pre, name, spec, conv in self._formatter.parse(self._obj):
            if spec:
                raise ValueError("no format specification supported by SQL")
            if conv:
                raise ValueError("no format conversion supported by SQL")

            # Literal text found before the field, if any.
            if pre:
                pieces.append(SQL(pre))

            # No field after the text: nothing to replace.
            if name is None:
                continue

            if not name:
                # Auto-numbered placeholder: {}
                if next_auto is None:
                    raise ValueError(
                        "cannot switch from manual field numbering to automatic"
                    )
                pieces.append(args[next_auto])
                next_auto += 1

            elif name.isdigit():
                # Numbered placeholder: {0}, {1}...
                if next_auto:
                    raise ValueError(
                        "cannot switch from automatic field numbering to manual"
                    )
                pieces.append(args[int(name)])
                next_auto = None

            else:
                # Named placeholder: {name}
                pieces.append(kwargs[name])

        return Composed(pieces)

    def join(self, seq: Iterable[Composable]) -> Composed:
        """
        Join a sequence of `Composable`.

        :param seq: the elements to join.
        :type seq: iterable of `!Composable`

        The `!SQL` object itself is used as separator between the elements of
        `!seq`. Because `Composed` objects are iterable too, they can be
        passed as argument to this method.

        Example::

            >>> snip = sql.SQL(', ').join(
            ...     sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
            >>> print(snip.as_string(conn))
            "foo", "bar", "baz"
        """
        joined = []
        for pos, item in enumerate(seq):
            # The separator goes before every item except the first.
            if pos:
                joined.append(self)
            joined.append(item)

        return Composed(joined)


class Identifier(Composable):
    """
    A `Composable` representing an SQL identifier or a dot-separated sequence.

    Identifiers are usually the names of database objects, such as tables or
    fields. The escaping rules for PostgreSQL identifiers are `different`__
    from the ones for SQL string literals (e.g. they are wrapped in double
    quotes instead of single ones).

    .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
        SQL-SYNTAX-IDENTIFIERS

    Example::

        >>> t1 = sql.Identifier("foo")
        >>> t2 = sql.Identifier("ba'r")
        >>> t3 = sql.Identifier('ba"z')
        >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
        "foo", "ba'r", "ba""z"

    A qualified name, i.e. a dot-separated sequence of identifiers, can be
    represented passing more than one string to the object.

    Example::

        >>> query = sql.SQL("SELECT {} FROM {}").format(
        ...     sql.Identifier("table", "field"),
        ...     sql.Identifier("schema", "table"))
        >>> print(query.as_string(conn))
        SELECT "table"."field" FROM "schema"."table"

    """

    _obj: Sequence[str]

    def __init__(self, *strings: str):
        # init super() now to make the __repr__ not explode in case of error
        super().__init__(strings)

        if not strings:
            raise TypeError("Identifier cannot be empty")

        # Complain about the first part which is not a string, if any.
        invalid = [part for part in strings if not isinstance(part, str)]
        if invalid:
            raise TypeError(
                f"SQL identifier parts must be strings, got {invalid[0]!r} instead"
            )

    def __repr__(self) -> str:
        rendered = ", ".join(repr(part) for part in self._obj)
        return f"{self.__class__.__name__}({rendered})"

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        conn = context.connection if context else None
        if not conn:
            raise ValueError("a connection is necessary for Identifier")

        escaper = Escaping(conn.pgconn)
        encoding = conn_encoding(conn)
        # Escape each part on its own, then join the parts in a dotted name.
        return b".".join(
            [escaper.escape_identifier(part.encode(encoding)) for part in self._obj]
        )


class Literal(Composable):
    """
    A `Composable` representing an SQL value to include in a query.

    The usual way to pass values to a query is to use placeholders and to
    pass the values as `~cursor.execute()` arguments. This object is for the
    cases in which you really need to include a literal value in the query
    instead.

    The string returned by `!as_string()` follows the normal :ref:`adaptation
    rules <types-adaptation>` for Python objects.

    Example::

        >>> s1 = sql.Literal("fo'o")
        >>> s2 = sql.Literal(42)
        >>> s3 = sql.Literal(date(2000, 1, 1))
        >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
        'fo''o', 42, '2000-01-01'::date

    """

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # The conversion is delegated to a transformer built from the context.
        return Transformer.from_context(context).as_literal(self._obj)


class Placeholder(Composable):
    """A `Composable` representing a placeholder for query parameters.

    With a name, the object renders as a named placeholder (e.g.
    ``%(name)s``, ``%(name)b``); without one it renders as a positional
    placeholder (e.g. ``%s``, ``%b``).

    The object is useful to generate SQL queries with a variable number of
    arguments.

    Examples::

        >>> names = ['foo', 'bar', 'baz']

        >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(sql.Placeholder() * len(names)))
        >>> print(q1.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)

        >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(map(sql.Placeholder, names)))
        >>> print(q2.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)

    """

    def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO):
        # Store the name before validating it.
        super().__init__(name)
        if not isinstance(name, str):
            raise TypeError(f"expected string as name, got {name!r}")

        # A closing parenthesis would end the ``%(name)`` marker too early.
        if ")" in name:
            raise ValueError(f"invalid name: {name!r}")

        # Only plain strings are converted; the exact type check presumably
        # avoids converting again values which are already `PyFormat` members.
        fmt = PyFormat(format) if type(format) is str else format
        if not isinstance(fmt, PyFormat):
            raise TypeError(
                f"expected PyFormat as format, got {type(fmt).__name__!r}"
            )

        self._format: PyFormat = fmt

    def __repr__(self) -> str:
        # Only show the arguments different from the defaults.
        shown = [repr(self._obj)] if self._obj else []
        if self._format is not PyFormat.AUTO:
            shown.append(f"format={self._format.name}")

        args = ", ".join(shown)
        return f"{self.__class__.__name__}({args})"

    def as_string(self, context: Optional[AdaptContext]) -> str:
        code = self._format.value
        if self._obj:
            return f"%({self._obj}){code}"
        return f"%{code}"

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        conn = None
        if context:
            conn = context.connection
        encoding = conn_encoding(conn)
        text = self.as_string(context)
        return text.encode(encoding)


# Literals
# Ready-made `SQL` snippets wrapping the ``NULL`` and ``DEFAULT`` keywords.
# Being `SQL` objects, they are rendered verbatim, with no escaping.
NULL = SQL("NULL")
DEFAULT = SQL("DEFAULT")