
sqlglot.dialects.dialect

from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
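# A minimal sketch of how a strategy is usually selected at runtime, via the
# settings string accepted by `Dialect.get_or_raise` below:
#
#     >>> from sqlglot.dialects.dialect import Dialect
#     >>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
#     >>> d.normalization_strategy
#     <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>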

class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                "\\a": "\a",
                "\\b": "\b",
                "\\f": "\f",
                "\\n": "\n",
                "\\r": "\r",
                "\\t": "\t",
                "\\v": "\v",
                "\\\\": "\\",
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass

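# A small sketch of the registration behavior implemented by this metaclass:
# merely defining a `Dialect` subclass makes it discoverable by name (the
# class name below is illustrative, not a built-in dialect):
#
#     >>> from sqlglot.dialects.dialect import Dialect
#     >>> class Custom(Dialect):
#     ...     pass
#     >>> Dialect.get("custom") is Custom
#     True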

class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

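    # The lookup above accepts several input shapes; a brief sketch:
    #
    #     >>> Dialect.get_or_raise(None)                   # default base dialect
    #     >>> Dialect.get_or_raise("duckdb")               # registry lookup by name
    #     >>> Dialect.get_or_raise(Dialect.get("duckdb"))  # a class is instantiated
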
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

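    # A sketch for a dialect with a populated TIME_MAPPING (this assumes the
    # bundled Hive dialect is loaded and maps `yyyy`, `MM` and `dd` to their
    # strftime equivalents):
    #
    #     >>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
    #     "'%Y-%m-%d'"
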
    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

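    # A short sketch contrasting two of the strategies described above
    # (assuming the bundled dialects are loaded):
    #
    #     >>> from sqlglot import exp
    #     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
    #     'foo'
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
    #     'FOO'
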
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

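    # Sketch: under the default lowercase strategy, identifiers containing
    # uppercase characters are "unsafe" and thus not identifiable in safe mode:
    #
    #     >>> d = Dialect()
    #     >>> d.can_identify("foo"), d.can_identify("Foo"), d.can_identify("Foo", "always")
    #     (True, False, True)
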
    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)

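# A minimal end-to-end sketch of the parse/generate plumbing above:
#
#     >>> d = Dialect.get_or_raise("duckdb")
#     >>> d.transpile("SELECT 1")
#     ['SELECT 1']
#     >>> [e.sql(dialect=d) for e in d.parse("SELECT 1")]
#     ['SELECT 1']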

DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))

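# `rename_func` is the typical way to express a one-to-one function rename in
# a dialect's generator; a sketch of how it might be wired up (the dialect
# class and transform entry below are illustrative):
#
#     class MyDialect(Dialect):
#         class Generator(Generator):
#             TRANSFORMS = {
#                 **Generator.TRANSFORMS,
#                 exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
#             }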

def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format; if `True`, the dialect's `TIME_FORMAT` is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder

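# Typical wiring lives in a dialect parser's FUNCTIONS table; a sketch (the
# dialect class and function entry shown are illustrative):
#
#     class MyDialect(Dialect):
#         class Parser(Parser):
#             FUNCTIONS = {
#                 **Parser.FUNCTIONS,
#                 "TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "mydialect"),
#             }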

def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"

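# Sketch: engines without a boolean XOR operator receive the equivalent
# expansion, e.g. `a XOR b` renders as `(a AND (NOT b)) OR ((NOT a) AND b)`.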

def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder

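# Sketch: literal arguments fold into a single JSONPath, e.g. an illustrative
# call like JSON_EXTRACT(x, 'a', 1) becomes the path `$.a[1]` under zero-based
# indexing, and `$.a[0]` when `zero_based_indexing=False`.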

def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder
433            and (
434                not expression.quoted
435                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
436            )
437        ):
438            expression.set(
439                "this",
440                (
441                    expression.this.upper()
442                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
443                    else expression.this.lower()
444                ),
445            )
446
447        return expression
448
449    def case_sensitive(self, text: str) -> bool:
450        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
451        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
452            return False
453
454        unsafe = (
455            str.islower
456            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
457            else str.isupper
458        )
459        return any(unsafe(char) for char in text)
460
461    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
462        """Checks if text can be identified given an identify option.
463
464        Args:
465            text: The text to check.
466            identify:
467                `"always"` or `True`: Always returns `True`.
468                `"safe"`: Only returns `True` if the identifier is case-insensitive.
469
470        Returns:
471            Whether the given text can be identified.
472        """
473        if identify is True or identify == "always":
474            return True
475
476        if identify == "safe":
477            return not self.case_sensitive(text)
478
479        return False
480
481    def quote_identifier(self, expression: E, identify: bool = True) -> E:
482        """
483        Adds quotes to a given identifier.
484
485        Args:
486            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
487            identify: If set to `False`, the quotes will only be added if the identifier is deemed
488                "unsafe", with respect to its characters and this dialect's normalization strategy.
489        """
490        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
491            name = expression.this
492            expression.set(
493                "quoted",
494                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
495            )
496
497        return expression
498
499    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
500        if isinstance(path, exp.Literal):
501            path_text = path.name
502            if path.is_number:
503                path_text = f"[{path_text}]"
504
505            try:
506                return parse_json_path(path_text)
507            except ParseError as e:
508                logger.warning(f"Invalid JSON path syntax. {str(e)}")
509
510        return path
511
512    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
513        return self.parser(**opts).parse(self.tokenize(sql), sql)
514
515    def parse_into(
516        self, expression_type: exp.IntoType, sql: str, **opts
517    ) -> t.List[t.Optional[exp.Expression]]:
518        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
519
520    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
521        return self.generator(**opts).generate(expression, copy=copy)
522
523    def transpile(self, sql: str, **opts) -> t.List[str]:
524        return [
525            self.generate(expression, copy=False, **opts) if expression else ""
526            for expression in self.parse(sql)
527        ]
528
529    def tokenize(self, sql: str) -> t.List[Token]:
530        return self.tokenizer.tokenize(sql)
531
532    @property
533    def tokenizer(self) -> Tokenizer:
534        if not hasattr(self, "_tokenizer"):
535            self._tokenizer = self.tokenizer_class(dialect=self)
536        return self._tokenizer
537
538    def parser(self, **opts) -> Parser:
539        return self.parser_class(dialect=self, **opts)
540
541    def generator(self, **opts) -> Generator:
542        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
395    def __init__(self, **kwargs) -> None:
396        normalization_strategy = kwargs.get("normalization_strategy")
397
398        if normalization_strategy is None:
399            self.normalization_strategy = self.NORMALIZATION_STRATEGY
400        else:
401            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
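
For illustration, a minimal sketch of what the constructor above accepts (the strategy string is upper-cased before being converted to a NormalizationStrategy member):

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # The strategy can be passed as a case-insensitive string...
    d = Dialect(normalization_strategy="case_sensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

    # ...and omitting it falls back to the class-level NORMALIZATION_STRATEGY.
    assert Dialect().normalization_strategy is NormalizationStrategy.LOWERCASE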
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (the two-argument form of LOG is not supported).

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

HEX_LOWERCASE = False

Whether the HEX function returns a lowercase hexadecimal string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (\n) to its unescaped version (here, a literal newline character).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
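
As a rough illustration, parsing the example above with a dialect that enables this flag (such as Snowflake) is expected to surface the injected aliases when generating SQL back out; a hedged sketch:

    import sqlglot

    sql = (
        "WITH y(c) AS ("
        "  SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0"
        ") SELECT c FROM y"
    )
    # With read="snowflake", the CTE alias column `c` should be pushed down
    # onto the unaliased SUM(a) projection, as described above.
    print(sqlglot.transpile(sql, read="snowflake")[0])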
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
COPY_PARAMS_ARE_CSV = True
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
331    @classmethod
332    def get_or_raise(cls, dialect: DialectType) -> Dialect:
333        """
334        Look up a dialect in the global dialect registry and return it if it exists.
335
336        Args:
337            dialect: The target dialect. If this is a string, it can be optionally followed by
338                additional key-value pairs that are separated by commas and are used to specify
339                dialect settings, such as whether the dialect's identifiers are case-sensitive.
340
341        Example:
342            >>> dialect = dialect_class = get_or_raise("duckdb")
343            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
344
345        Returns:
346            The corresponding Dialect instance.
347        """
348
349        if not dialect:
350            return cls()
351        if isinstance(dialect, _Dialect):
352            return dialect()
353        if isinstance(dialect, Dialect):
354            return dialect
355        if isinstance(dialect, str):
356            try:
357                dialect_name, *kv_pairs = dialect.split(",")
358                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
359            except ValueError:
360                raise ValueError(
361                    f"Invalid dialect format: '{dialect}'. "
362                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
363                )
364
365            result = cls.get(dialect_name.strip())
366            if not result:
367                from difflib import get_close_matches
368
369                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
370                if similar:
371                    similar = f" Did you mean {similar}?"
372
373                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
374
375            return result(**kwargs)
376
377        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
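
A short usage sketch of the lookup logic above:

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # Plain lookup by name returns an instance of the registered dialect class.
    duckdb = Dialect.get_or_raise("duckdb")

    # Settings may be appended as comma-separated key-value pairs.
    mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

    # Unknown names raise ValueError, possibly with a "Did you mean ...?" hint.
    # Dialect.get_or_raise("duckdbb")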

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
379    @classmethod
380    def format_time(
381        cls, expression: t.Optional[str | exp.Expression]
382    ) -> t.Optional[exp.Expression]:
383        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
384        if isinstance(expression, str):
385            return exp.Literal.string(
386                # the time formats are quoted
387                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
388            )
389
390        if expression and expression.is_string:
391            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
392
393        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
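
For example, assuming a dialect whose TIME_MAPPING covers Hive-style tokens (a sketch; the exact mapping lives in the Hive dialect):

    from sqlglot.dialects.dialect import Dialect

    # Plain strings are treated as quoted time formats, hence the inner quotes.
    literal = Dialect.get_or_raise("hive").format_time("'yyyy-MM-dd'")
    print(literal.sql())  # expected: '%Y-%m-%d'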

def normalize_identifier(self, expression: ~E) -> ~E:
411    def normalize_identifier(self, expression: E) -> E:
412        """
413        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
414
415        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
416        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
417        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
418        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
419
420        There are also dialects like Spark, which are case-insensitive even when quotes are
421        present, and dialects like MySQL, whose resolution rules match those employed by the
422        underlying operating system, for example they may always be case-sensitive in Linux.
423
424        Finally, the normalization behavior of some engines can even be controlled through flags,
425        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
426
427        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
428        that it can analyze queries in the optimizer and successfully capture their semantics.
429        """
430        if (
431            isinstance(expression, exp.Identifier)
432            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
433            and (
434                not expression.quoted
435                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
436            )
437        ):
438            expression.set(
439                "this",
440                (
441                    expression.this.upper()
442                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
443                    else expression.this.lower()
444                ),
445            )
446
447        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
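
A compact sketch of the behaviors described above:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO", quoted=False)

    # Postgres lowercases unquoted identifiers, Snowflake uppercases them.
    print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).this)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).this)  # FOO

    # Quoted identifiers are left untouched by dialects where quoting opts
    # into case sensitivity.
    quoted = exp.to_identifier("FoO", quoted=True)
    print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).this)  # FoO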

def case_sensitive(self, text: str) -> bool:
449    def case_sensitive(self, text: str) -> bool:
450        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
451        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
452            return False
453
454        unsafe = (
455            str.islower
456            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
457            else str.isupper
458        )
459        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.
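
For instance, under the default LOWERCASE strategy any uppercase character makes the text case-sensitive, while under UPPERCASE it is a lowercase one:

    from sqlglot.dialects.dialect import Dialect

    print(Dialect.get_or_raise("postgres").case_sensitive("foo"))   # False
    print(Dialect.get_or_raise("postgres").case_sensitive("Foo"))   # True
    print(Dialect.get_or_raise("snowflake").case_sensitive("FOO"))  # False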

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
461    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
462        """Checks if text can be identified given an identify option.
463
464        Args:
465            text: The text to check.
466            identify:
467                `"always"` or `True`: Always returns `True`.
468                `"safe"`: Only returns `True` if the identifier is case-insensitive.
469
470        Returns:
471            Whether the given text can be identified.
472        """
473        if identify is True or identify == "always":
474            return True
475
476        if identify == "safe":
477            return not self.case_sensitive(text)
478
479        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
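
A brief sketch tying this to case_sensitive above:

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    print(pg.can_identify("foo", "safe"))    # True: all-lowercase is case-insensitive here
    print(pg.can_identify("Foo", "safe"))    # False: mixed case is case-sensitive
    print(pg.can_identify("Foo", "always"))  # True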

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
481    def quote_identifier(self, expression: E, identify: bool = True) -> E:
482        """
483        Adds quotes to a given identifier.
484
485        Args:
486            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
487            identify: If set to `False`, the quotes will only be added if the identifier is deemed
488                "unsafe", with respect to its characters and this dialect's normalization strategy.
489        """
490        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
491            name = expression.this
492            expression.set(
493                "quoted",
494                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
495            )
496
497        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
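
For example (a sketch; SAFE_IDENTIFIER_RE is the library's pattern for names that don't require quoting):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")

    # With identify=False, only "unsafe" identifiers get quoted.
    print(pg.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo
    print(pg.quote_identifier(exp.to_identifier("Foo"), identify=False).sql())  # "Foo"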
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
499    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
500        if isinstance(path, exp.Literal):
501            path_text = path.name
502            if path.is_number:
503                path_text = f"[{path_text}]"
504
505            try:
506                return parse_json_path(path_text)
507            except ParseError as e:
508                logger.warning(f"Invalid JSON path syntax. {str(e)}")
509
510        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
512    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
513        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
515    def parse_into(
516        self, expression_type: exp.IntoType, sql: str, **opts
517    ) -> t.List[t.Optional[exp.Expression]]:
518        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
520    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
521        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
523    def transpile(self, sql: str, **opts) -> t.List[str]:
524        return [
525            self.generate(expression, copy=False, **opts) if expression else ""
526            for expression in self.parse(sql)
527        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
529    def tokenize(self, sql: str) -> t.List[Token]:
530        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
532    @property
533    def tokenizer(self) -> Tokenizer:
534        if not hasattr(self, "_tokenizer"):
535            self._tokenizer = self.tokenizer_class(dialect=self)
536        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
538    def parser(self, **opts) -> Parser:
539        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
541    def generator(self, **opts) -> Generator:
542        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
548def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
549    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
552def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
553    if expression.args.get("accuracy"):
554        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
555    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
558def if_sql(
559    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
560) -> t.Callable[[Generator, exp.If], str]:
561    def _if_sql(self: Generator, expression: exp.If) -> str:
562        return self.func(
563            name,
564            expression.this,
565            expression.args.get("true"),
566            expression.args.get("false") or false_value,
567        )
568
569    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
572def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
573    this = expression.this
574    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
575        this.replace(exp.cast(this, exp.DataType.Type.JSON))
576
577    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
580def inline_array_sql(self: Generator, expression: exp.Array) -> str:
581    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
584def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
585    elem = seq_get(expression.expressions, 0)
586    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
587        return self.func("ARRAY", elem)
588    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
591def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
592    return self.like_sql(
593        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
594    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
597def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
598    zone = self.sql(expression, "this")
599    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
602def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
603    if expression.args.get("recursive"):
604        self.unsupported("Recursive CTEs are unsupported")
605        expression.args["recursive"] = False
606    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
609def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
610    n = self.sql(expression, "this")
611    d = self.sql(expression, "expression")
612    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
615def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
616    self.unsupported("TABLESAMPLE unsupported")
617    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
620def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
621    self.unsupported("PIVOT unsupported")
622    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
625def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
626    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
629def no_comment_column_constraint_sql(
630    self: Generator, expression: exp.CommentColumnConstraint
631) -> str:
632    self.unsupported("CommentColumnConstraint unsupported")
633    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
636def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
637    self.unsupported("MAP_FROM_ENTRIES unsupported")
638    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
641def str_position_sql(
642    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
643) -> str:
644    this = self.sql(expression, "this")
645    substr = self.sql(expression, "substr")
646    position = self.sql(expression, "position")
647    instance = expression.args.get("instance") if generate_instance else None
648    position_offset = ""
649
650    if position:
651        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
652        this = self.func("SUBSTR", this, position)
653        position_offset = f" + {position} - 1"
654
655    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
658def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
659    return (
660        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
661    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
664def var_map_sql(
665    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
666) -> str:
667    keys = expression.args["keys"]
668    values = expression.args["values"]
669
670    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
671        self.unsupported("Cannot convert array columns into map.")
672        return self.func(map_func_name, keys, values)
673
674    args = []
675    for key, value in zip(keys.expressions, values.expressions):
676        args.append(self.sql(key))
677        args.append(self.sql(value))
678
679    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
682def build_formatted_time(
683    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
684) -> t.Callable[[t.List], E]:
685    """Helper used for time expressions.
686
687    Args:
688        exp_class: the expression class to instantiate.
689        dialect: target sql dialect.
690        default: the default format, True being time.
691
692    Returns:
693        A callable that can be used to return the appropriately formatted time expression.
694    """
695
696    def _builder(args: t.List):
697        return exp_class(
698            this=seq_get(args, 0),
699            format=Dialect[dialect].format_time(
700                seq_get(args, 1)
701                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
702            ),
703        )
704
705    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
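
A hypothetical builder for Hive-style STR_TO_TIME arguments; the generated SQL below uses the default dialect and may differ per target:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    builder = build_formatted_time(exp.StrToTime, "hive")
    node = builder([exp.Literal.string("2024-01-01"), exp.Literal.string("yyyy-MM-dd")])
    print(node.sql())  # e.g. STR_TO_TIME('2024-01-01', '%Y-%m-%d')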

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
708def time_format(
709    dialect: DialectType = None,
710) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
711    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
712        """
713        Returns the time format for a given expression, unless it's equivalent
714        to the default time format of the dialect of interest.
715        """
716        time_format = self.format_time(expression)
717        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
718
719    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
722def build_date_delta(
723    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
724) -> t.Callable[[t.List], E]:
725    def _builder(args: t.List) -> E:
726        unit_based = len(args) == 3
727        this = args[2] if unit_based else seq_get(args, 0)
728        unit = args[0] if unit_based else exp.Literal.string("DAY")
729        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
730        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
731
732    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
735def build_date_delta_with_interval(
736    expression_class: t.Type[E],
737) -> t.Callable[[t.List], t.Optional[E]]:
738    def _builder(args: t.List) -> t.Optional[E]:
739        if len(args) < 2:
740            return None
741
742        interval = args[1]
743
744        if not isinstance(interval, exp.Interval):
745            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
746
747        expression = interval.this
748        if expression and expression.is_string:
749            expression = exp.Literal.number(expression.this)
750
751        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
752
753    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
756def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
757    unit = seq_get(args, 0)
758    this = seq_get(args, 1)
759
760    if isinstance(this, exp.Cast) and this.is_type("date"):
761        return exp.DateTrunc(unit=unit, this=this)
762    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
765def date_add_interval_sql(
766    data_type: str, kind: str
767) -> t.Callable[[Generator, exp.Expression], str]:
768    def func(self: Generator, expression: exp.Expression) -> str:
769        this = self.sql(expression, "this")
770        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
771        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
772
773    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
776def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
777    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
780def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
781    if not expression.expression:
782        from sqlglot.optimizer.annotate_types import annotate_types
783
784        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
785        return self.sql(exp.cast(expression.this, target_type))
786    if expression.text("expression").lower() in TIMEZONES:
787        return self.sql(
788            exp.AtTimeZone(
789                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
790                zone=expression.expression,
791            )
792        )
793    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
796def locate_to_strposition(args: t.List) -> exp.Expression:
797    return exp.StrPosition(
798        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
799    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
802def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
803    return self.func(
804        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
805    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
808def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
809    return self.sql(
810        exp.Substring(
811            this=expression.this, start=exp.Literal.number(1), length=expression.expression
812        )
813    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
816def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
817    return self.sql(
818        exp.Substring(
819            this=expression.this,
820            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
821        )
822    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
825def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
826    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
829def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
830    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
834def encode_decode_sql(
835    self: Generator, expression: exp.Expression, name: str, replace: bool = True
836) -> str:
837    charset = expression.args.get("charset")
838    if charset and charset.name.lower() != "utf-8":
839        self.unsupported(f"Expected utf-8 character set, got {charset}.")
840
841    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
844def min_or_least(self: Generator, expression: exp.Min) -> str:
845    name = "LEAST" if expression.expressions else "MIN"
846    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
849def max_or_greatest(self: Generator, expression: exp.Max) -> str:
850    name = "GREATEST" if expression.expressions else "MAX"
851    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
854def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
855    cond = expression.this
856
857    if isinstance(expression.this, exp.Distinct):
858        cond = expression.this.expressions[0]
859        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
860
861    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
864def trim_sql(self: Generator, expression: exp.Trim) -> str:
865    target = self.sql(expression, "this")
866    trim_type = self.sql(expression, "position")
867    remove_chars = self.sql(expression, "expression")
868    collation = self.sql(expression, "collation")
869
870    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
871    if not remove_chars and not collation:
872        return self.trim_sql(expression)
873
874    trim_type = f"{trim_type} " if trim_type else ""
875    remove_chars = f"{remove_chars} " if remove_chars else ""
876    from_part = "FROM " if trim_type or remove_chars else ""
877    collation = f" COLLATE {collation}" if collation else ""
878    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
881def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
882    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
885def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
886    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
889def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
890    delim, *rest_args = expression.expressions
891    return self.sql(
892        reduce(
893            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
894            rest_args,
895        )
896    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
899def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
900    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
901    if bad_args:
902        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
903
904    return self.func(
905        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
906    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
909def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
910    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
911    if bad_args:
912        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
913
914    return self.func(
915        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
916    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
919def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
920    names = []
921    for agg in aggregations:
922        if isinstance(agg, exp.Alias):
923            names.append(agg.alias)
924        else:
925            """
926            This case corresponds to aggregations without aliases being used as suffixes
927            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
928            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
929            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
930            """
931            agg_all_unquoted = agg.transform(
932                lambda node: (
933                    exp.Identifier(this=node.name, quoted=False)
934                    if isinstance(node, exp.Identifier)
935                    else node
936                )
937            )
938            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
939
940    return names
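
A hypothetical call mirroring the docstring above; unaliased aggregations turn into lowercase, unquoted name suffixes, while aliased ones use the alias directly:

    import sqlglot
    from sqlglot import exp
    from sqlglot.dialects.dialect import pivot_column_names

    aggs = [
        sqlglot.parse_one("AVG(foo)"),                    # no alias
        exp.alias_(sqlglot.parse_one("MAX(bar)"), "mx"),  # explicit alias
    ]
    print(pivot_column_names(aggs, dialect="spark"))  # e.g. ['avg(foo)', 'mx']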
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
943def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
944    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
948def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
949    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
952def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
953    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
956def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
957    a = self.sql(expression.left)
958    b = self.sql(expression.right)
959    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
962def is_parse_json(expression: exp.Expression) -> bool:
963    return isinstance(expression, exp.ParseJSON) or (
964        isinstance(expression, exp.Cast) and expression.is_type("json")
965    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
968def isnull_to_is_null(args: t.List) -> exp.Expression:
969    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
972def generatedasidentitycolumnconstraint_sql(
973    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
974) -> str:
975    start = self.sql(expression, "start") or "1"
976    increment = self.sql(expression, "increment") or "1"
977    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
980def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
981    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
982        if expression.args.get("count"):
983            self.unsupported(f"Only two arguments are supported in function {name}.")
984
985        return self.func(name, expression.this, expression.expression)
986
987    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
 990def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 991    this = expression.this.copy()
 992
 993    return_type = expression.return_type
 994    if return_type.is_type(exp.DataType.Type.DATE):
 995        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 996        # can truncate timestamp strings, because some dialects can't cast them to DATE
 997        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 998
 999    expression.this.replace(exp.cast(this, return_type))
1000    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
1003def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1004    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1005        if cast and isinstance(expression, exp.TsOrDsAdd):
1006            expression = ts_or_ds_add_cast(expression)
1007
1008        return self.func(
1009            name,
1010            unit_to_var(expression),
1011            expression.expression,
1012            expression.this,
1013        )
1014
1015    return _delta_sql
def unit_to_str( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1018def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1019    unit = expression.args.get("unit")
1020
1021    if isinstance(unit, exp.Placeholder):
1022        return unit
1023    if unit:
1024        return exp.Literal.string(unit.name)
1025    return exp.Literal.string(default) if default else None
def unit_to_var( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1028def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1029    unit = expression.args.get("unit")
1030
1031    if isinstance(unit, (exp.Var, exp.Placeholder)):
1032        return unit
1033    return exp.Var(this=default) if default else None
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1036def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1037    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1038    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1039    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1040
1041    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1044def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1045    """Remove table refs from columns in when statements."""
1046    alias = expression.this.args.get("alias")
1047
1048    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1049        return self.dialect.normalize_identifier(identifier).name if identifier else None
1050
1051    targets = {normalize(expression.this.this)}
1052
1053    if alias:
1054        targets.add(normalize(alias.this))
1055
1056    for when in expression.expressions:
1057        when.transform(
1058            lambda node: (
1059                exp.column(node.this)
1060                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1061                else node
1062            ),
1063            copy=False,
1064        )
1065
1066    return self.merge_sql(expression)

Remove table refs from columns in when statements.
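
A hedged sketch of calling it directly with a generator; the MERGE below is parsed with the default dialect:

    import sqlglot
    from sqlglot.dialects.dialect import Dialect, merge_without_target_sql

    merge = sqlglot.parse_one(
        "MERGE INTO t USING s ON t.id = s.id "
        "WHEN MATCHED THEN UPDATE SET t.v = s.v"
    )
    gen = Dialect.get_or_raise("duckdb").generator()
    # Inside the WHEN clauses, columns qualified with the target table
    # ('t.v') should lose their table prefix.
    print(merge_without_target_sql(gen, merge))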

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1069def build_json_extract_path(
1070    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1071) -> t.Callable[[t.List], F]:
1072    def _builder(args: t.List) -> F:
1073        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1074        for arg in args[1:]:
1075            if not isinstance(arg, exp.Literal):
1076                # We use the fallback parser because we can't really transpile non-literals safely
1077                return expr_type.from_arg_list(args)
1078
1079            text = arg.name
1080            if is_int(text):
1081                index = int(text)
1082                segments.append(
1083                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1084                )
1085            else:
1086                segments.append(exp.JSONPathKey(this=text))
1087
1088        # This is done to avoid failing in the expression validator due to the arg count
1089        del args[2:]
1090        return expr_type(
1091            this=seq_get(args, 0),
1092            expression=exp.JSONPath(expressions=segments),
1093            only_json_types=arrow_req_json_type,
1094        )
1095
1096    return _builder
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1099def json_extract_segments(
1100    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1101) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1102    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1103        path = expression.expression
1104        if not isinstance(path, exp.JSONPath):
1105            return rename_func(name)(self, expression)
1106
1107        segments = []
1108        for segment in path.expressions:
1109            path = self.sql(segment)
1110            if path:
1111                if isinstance(segment, exp.JSONPathPart) and (
1112                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1113                ):
1114                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1115
1116                segments.append(path)
1117
1118        if op:
1119            return f" {op} ".join([self.sql(expression.this), *segments])
1120        return self.func(name, expression.this, *segments)
1121
1122    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1125def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1126    if isinstance(expression.this, exp.JSONPathWildcard):
1127        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1128
1129    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1132def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1133    cond = expression.expression
1134    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1135        alias = cond.expressions[0]
1136        cond = cond.this
1137    elif isinstance(cond, exp.Predicate):
1138        alias = "_u"
1139    else:
1140        self.unsupported("Unsupported filter condition")
1141        return ""
1142
1143    unnest = exp.Unnest(expressions=[expression.this])
1144    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1145    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ToNumber) -> str:
1148def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1149    return self.func(
1150        "TO_NUMBER",
1151        expression.this,
1152        expression.args.get("format"),
1153        expression.args.get("nlsparam"),
1154    )
def build_default_decimal_type( precision: Optional[int] = None, scale: Optional[int] = None) -> Callable[[sqlglot.expressions.DataType], sqlglot.expressions.DataType]:
1157def build_default_decimal_type(
1158    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1159) -> t.Callable[[exp.DataType], exp.DataType]:
1160    def _builder(dtype: exp.DataType) -> exp.DataType:
1161        if dtype.expressions or precision is None:
1162            return dtype
1163
1164        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1165        return exp.DataType.build(f"DECIMAL({params})")
1166
1167    return _builder