diff --git a/CHANGELOG.md b/CHANGELOG.md index 86a7dd72c..d754b784d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # Changelog +### v1.10.1 + +#### Features + +- Add `dbt_sqlserver_enable_safe_type_expansion` behaviour flag to allow safe column type widening during schema expansion: `varchar` → `nvarchar`, integer family promotions (`bit` → `tinyint` → `smallint` → `int` → `bigint`), and `numeric`/`decimal` precision/scale upgrades. Gated by the per-model `column_type_expansion_max_rows` config (default 1,000,000 rows). See [#699](https://github.com/dbt-msft/dbt-sqlserver/issues/699). +- Add `prefer_single_alter_column` model config to use a single `ALTER COLUMN` statement instead of the add+update+drop+rename pattern when altering column types on tables. +- Add `string_type_instance()` to preserve the NVARCHAR/NCHAR type family during column expansion, fixing incorrect promotion of NVARCHAR/NCHAR to VARCHAR. +- Add `tinyint` and `bit` to the `is_integer()` type list for correct type detection. + +#### Bugfixes + +- Fix catalog generation for NVARCHAR/NCHAR columns: use `user_type_id` instead of `system_type_id` in catalog.sql, preventing them from appearing as `SYSNAME` in `dbt docs`. [#637](https://github.com/dbt-msft/dbt-sqlserver/issues/637) +- Fix `is_numeric()` to exclude `money`/`smallmoney` (now `is_fixed_numeric()`), preventing incorrect type expansion for fixed-precision money types. +- Fix seed table ingestion of empty numeric cells by inlining `null` literals instead of binding parameters. [#425](https://github.com/dbt-msft/dbt-sqlserver/issues/425) +- Fix integer-to-numeric safe expansion to require sufficient precision (e.g. `int` → `numeric(10,0)` minimum), avoiding data-loss risk. + +#### Migration note + +- `money` and `smallmoney` columns are no longer classified as `is_numeric()`. If you have custom code or macros that depend on `money` being numeric, use `is_number()` (which covers all numeric types) or `is_fixed_numeric()` for money types specifically. + ### v1.10.0 #### Features diff --git a/README.md b/README.md index 3c431ec0a..45636a185 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,44 @@ vars: *(default: `pyodbc`)* Set to `mssql-python` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required backend package (Python dependency), such as `pyodbc` or `mssql-python`, is not installed. +### `dbt_sqlserver_enable_safe_type_expansion` + +*(default: `false`)* When enabled, allows the adapter to widen column types during incremental model schema expansion beyond same-family string resizes. Supported safe expansions include: + +- **Cross-family string**: `varchar`/`char` → `nvarchar`/`nchar` (same or larger size) +- **Integer family**: `bit` → `tinyint` → `smallint` → `int` → `bigint` +- **Integer → numeric**: `int` → `numeric` (with sufficient precision to hold the integer range) +- **Numeric precision/scale**: `numeric(p,s)` → `numeric(p2,s2)` where precision and scale both increase +- **Fixed-money**: `smallmoney` → `money`, `money` → `numeric` (with sufficient precision) + +Safe expansions are further gated by `column_type_expansion_max_rows` (default 1,000,000 rows) to avoid long-running operations on large tables. + +```yaml +# dbt_project.yml +flags: + dbt_sqlserver_enable_safe_type_expansion: true +``` + +### `column_type_expansion_max_rows` + +*(default: `1000000`)* Per-model config that limits when safe type expansion runs. When the target table exceeds this row count, safe type expansion is skipped (basic same-family string resizes still proceed). Set to `-1` to disable the check entirely. + +```sql +-- In an incremental model +{{ config(materialized='incremental', unique_key='id', + column_type_expansion_max_rows=500000) }} +``` + +### `prefer_single_alter_column` + +*(default: `false`)* Model-level config that controls how `alter_column_type` changes column types on tables. When `false` (default), the adapter uses the safer approach: add a temporary column, copy data, drop the original, and rename. When `true`, the adapter uses a single `ALTER COLUMN` statement, which is faster on small, medium tables and instant on safe type expansions but may fail for types that cannot be implicitly converted. + +```sql +-- In an incremental model +{{ config(materialized='incremental', unique_key='id', + prefer_single_alter_column=true) }} +``` + ## Contributing [![Unit tests](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/unit-tests.yml) diff --git a/dbt/adapters/sqlserver/sqlserver_adapter.py b/dbt/adapters/sqlserver/sqlserver_adapter.py index a575cf42c..152c8eaa3 100644 --- a/dbt/adapters/sqlserver/sqlserver_adapter.py +++ b/dbt/adapters/sqlserver/sqlserver_adapter.py @@ -15,7 +15,8 @@ from dbt.adapters.base.meta import available from dbt.adapters.base.relation import BaseRelation from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support -from dbt.adapters.events.types import SchemaCreation +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.events.types import ColTypeChange, SchemaCreation from dbt.adapters.reference_keys import _make_ref_key_dict from dbt.adapters.sql.impl import CREATE_SCHEMA_MACRO_NAME, SQLAdapter from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn, SQLServerColumnNative @@ -23,6 +24,8 @@ from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation +logger = AdapterLogger("SQLServer") + class SQLServerAdapter(SQLAdapter): """ @@ -99,6 +102,16 @@ def _behavior_flags(self) -> List[BehaviorFlag]: "The new behaviour is intended to become the default in a future release." ), }, + { + "name": "dbt_sqlserver_enable_safe_type_expansion", + "default": False, + "description": ( + "Allow the SQL Server adapter to widen column types during schema expansion. " + "This enables promotions like varchar -> nvarchar, " + "bit -> tinyint -> smallint -> int -> bigint, " + "and numeric(p,s) -> numeric(p2,s2) using alter column." + ), + }, ] @available.parse(lambda *a, **k: []) @@ -288,6 +301,109 @@ def render_model_constraint(cls, constraint: ModelLevelConstraint) -> Optional[s else: return None + def _get_row_count(self, relation) -> int: + """Return the number of rows in the given relation.""" + sql = f"SELECT COUNT_BIG(*) FROM {relation}" + _, cursor = self.connections.add_select_query(sql) + row = cursor.fetchone() + return int(row[0]) if row else 0 + + def expand_column_types(self, goal, current, max_rows: int = 1000000): + """Override to ensure we preserve nvarchar/nchar type family during + column expansion. Necessary same-family resizes (e.g. varchar size) + always proceed. Safe type expansions (cross-family promotions like + varchar -> nvarchar) are guarded by column_type_expansion_max_rows. + enable_safe_type_expansion is the future approach for widening.""" + + reference_columns = {c.name: c for c in self.get_columns_in_relation(goal)} + target_columns = {c.name: c for c in self.get_columns_in_relation(current)} + + enable_safe = self.behavior.dbt_sqlserver_enable_safe_type_expansion + + row_count_exceeds = False + if enable_safe and max_rows != -1: + if max_rows == 0: + row_count_exceeds = True + logger.info( + "Safe type expansion skipped for %s: " "column_type_expansion_max_rows is 0.", + current, + ) + else: + row_count = self._get_row_count(current) + if row_count > max_rows: + row_count_exceeds = True + logger.warning( + "Safe type expansion skipped for %s: " + "%s rows exceeds column_type_expansion_max_rows (%s). " + "Set column_type_expansion_max_rows=-1 to disable " + "this check, or increase the limit.", + current, + row_count, + max_rows, + ) + + for column_name, reference_column in reference_columns.items(): + target_column = target_columns.get(column_name) + if target_column is None: + continue + + if target_column.can_expand_to(reference_column): + pass + elif ( + enable_safe + and not row_count_exceeds + and target_column.can_expand_safe(reference_column) + ): + pass + else: + continue + + if reference_column.is_string(): + col_string_size = reference_column.string_size() + new_type = reference_column.string_type_instance(col_string_size) + else: + new_type = reference_column.data_type + fire_event( + ColTypeChange( + orig_type=target_column.data_type, + new_type=new_type, + table=_make_ref_key_dict(current), + ) + ) + self.alter_column_type(current, column_name, new_type) + + @available.parse_none + def expand_target_column_types( + self, from_relation: BaseRelation, to_relation: BaseRelation, max_rows: int = 1000000 + ) -> None: + if not isinstance(from_relation, self.Relation): + from dbt.adapters.base.impl import MacroArgTypeError + + raise MacroArgTypeError( + method_name="expand_target_column_types", + arg_name="from_relation", + got_value=from_relation, + expected_type=self.Relation, + ) + if not isinstance(to_relation, self.Relation): + from dbt.adapters.base.impl import MacroArgTypeError + + raise MacroArgTypeError( + method_name="expand_target_column_types", + arg_name="to_relation", + got_value=to_relation, + expected_type=self.Relation, + ) + self.expand_column_types(from_relation, to_relation, max_rows) + + def alter_column_type(self, relation, column_name, new_column_type): + kwargs = { + "relation": relation, + "column_name": column_name, + "new_column_type": new_column_type, + } + self.execute_macro("alter_column_type", kwargs=kwargs) + COLUMNS_EQUAL_SQL = """ with diff_count as ( diff --git a/dbt/adapters/sqlserver/sqlserver_column.py b/dbt/adapters/sqlserver/sqlserver_column.py index d93281b5f..10bb66539 100644 --- a/dbt/adapters/sqlserver/sqlserver_column.py +++ b/dbt/adapters/sqlserver/sqlserver_column.py @@ -37,6 +37,23 @@ class SQLServerColumn(Column): @classmethod def string_type(cls, size: int) -> str: + """Class-level string_type used by SQLAdapter.expand_column_types. + + Return a VARCHAR default for the SQLAdapter path; this keeps behaviour + consistent with the rest of dbt where class-level string_type is + generic and not instance-aware. + """ + return f"varchar({size if size > 0 else '8000'})" + + def string_type_instance(self, size: int) -> str: + """Instance-level string type selection that respects NVARCHAR/NCHAR.""" + dtype = (self.dtype or "").lower() + if dtype == "nvarchar": + return f"nvarchar({size if size > 0 else '4000'})" + if dtype == "nchar": + return f"nchar({size if size > 0 else '1'})" + if dtype == "char": + return f"char({size if size > 0 else '1'})" return f"varchar({size if size > 0 else '8000'})" def literal(self, value: Any) -> str: @@ -48,31 +65,31 @@ def data_type(self) -> str: if self.dtype.lower() == "datetime2": return "datetime2(6)" if self.is_string(): - return self.string_type(self.string_size()) + return self.string_type_instance(self.string_size()) elif self.is_numeric(): return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) else: return self.dtype def is_string(self) -> bool: - return self.dtype.lower() in ["varchar", "char"] + return self.dtype.lower() in ["varchar", "char", "nvarchar", "nchar"] def is_number(self): - return any([self.is_integer(), self.is_numeric(), self.is_float()]) + return any( + [self.is_integer(), self.is_numeric(), self.is_float(), self.is_fixed_numeric()] + ) def is_float(self): return self.dtype.lower() in ["float", "real"] def is_integer(self) -> bool: return self.dtype.lower() in [ - # real types "smallint", "integer", "bigint", "smallserial", "serial", "bigserial", - # aliases "int2", "int4", "int8", @@ -80,10 +97,15 @@ def is_integer(self) -> bool: "serial4", "serial8", "int", + "tinyint", + "bit", ] def is_numeric(self) -> bool: - return self.dtype.lower() in ["numeric", "decimal", "money", "smallmoney"] + return self.dtype.lower() in ["numeric", "decimal"] + + def is_fixed_numeric(self) -> bool: + return self.dtype.lower() in ["money", "smallmoney"] def string_size(self) -> int: if not self.is_string(): @@ -93,10 +115,64 @@ def string_size(self) -> int: else: return int(self.char_size) - def can_expand_to(self, other_column: "SQLServerColumn") -> bool: - if not self.is_string() or not other_column.is_string(): + def can_expand_to(self, other_column: "Column") -> bool: + self_dtype = self.dtype.lower() + other_dtype = other_column.dtype.lower() + if self.is_string() and other_column.is_string(): + self_size = self.string_size() + other_size = other_column.string_size() + if other_size > self_size and self_dtype == other_dtype: + return True + return False + + def can_expand_safe(self, other_column: "SQLServerColumn") -> bool: + self_dtype = self.dtype.lower() + other_dtype = other_column.dtype.lower() + + if self.is_string() and other_column.is_string(): + self_size = self.string_size() + other_size = other_column.string_size() + if self_dtype in ("varchar", "char") and other_dtype in ("nvarchar", "nchar"): + return other_size >= self_size return False - return other_column.string_size() > self.string_size() + + if not self.is_number() or not other_column.is_number(): + return False + + int_family = ("bit", "tinyint", "smallint", "int", "bigint") + if self_dtype in int_family and other_dtype in int_family: + return int_family.index(other_dtype) > int_family.index(self_dtype) + + self_prec = int(self.numeric_precision or 0) + other_prec = int(other_column.numeric_precision or 0) + + if self.is_integer() and other_column.is_numeric(): + minimum_int_precision: int + if self_dtype in ("tinyint",): + minimum_int_precision = 3 + elif self_dtype in ("smallint", "int2"): + minimum_int_precision = 5 + elif self_dtype in ("bigint", "int8", "bigserial", "serial8"): + minimum_int_precision = 19 + elif self_dtype in ("bit",): + minimum_int_precision = 1 + else: + minimum_int_precision = 10 + effective_self_prec = max(self_prec, minimum_int_precision) + if other_prec >= effective_self_prec: + return True + + if (self.is_numeric() or self.is_fixed_numeric()) and ( + other_column.is_numeric() or other_column.is_fixed_numeric() + ): + self_scale = int(self.numeric_scale or 0) + other_scale = int(other_column.numeric_scale or 0) + + if other_prec >= self_prec and other_scale >= self_scale: + if other_prec > self_prec or other_scale > self_scale or self_dtype != other_dtype: + return True + + return False class SQLServerColumnNative(SQLServerColumn): diff --git a/dbt/adapters/sqlserver/sqlserver_configs.py b/dbt/adapters/sqlserver/sqlserver_configs.py index bf6d2d1e2..ca125705a 100644 --- a/dbt/adapters/sqlserver/sqlserver_configs.py +++ b/dbt/adapters/sqlserver/sqlserver_configs.py @@ -7,3 +7,5 @@ @dataclass class SQLServerConfigs(AdapterConfig): auto_provision_aad_principals: Optional[bool] = False + prefer_single_alter_column: Optional[bool] = False + column_type_expansion_max_rows: Optional[int] = None diff --git a/dbt/include/sqlserver/macros/adapters/catalog.sql b/dbt/include/sqlserver/macros/adapters/catalog.sql index 8e4b5c161..03223f70c 100644 --- a/dbt/include/sqlserver/macros/adapters/catalog.sql +++ b/dbt/include/sqlserver/macros/adapters/catalog.sql @@ -96,7 +96,7 @@ c.column_id as column_index, t.name as column_type from sys.columns as c {{ information_schema_hints() }} - left join sys.types as t {{ information_schema_hints() }} on c.system_type_id = t.system_type_id + left join sys.types as t {{ information_schema_hints() }} on c.user_type_id = t.user_type_id ) select @@ -226,7 +226,7 @@ c.column_id as column_index, t.name as column_type from sys.columns as c {{ information_schema_hints() }} - left join sys.types as t on c.system_type_id = t.system_type_id + left join sys.types as t on c.user_type_id = t.user_type_id ) select diff --git a/dbt/include/sqlserver/macros/adapters/columns.sql b/dbt/include/sqlserver/macros/adapters/columns.sql index a9bc4bfe6..626128278 100644 --- a/dbt/include/sqlserver/macros/adapters/columns.sql +++ b/dbt/include/sqlserver/macros/adapters/columns.sql @@ -26,27 +26,38 @@ {% macro sqlserver__alter_column_type(relation, column_name, new_column_type) %} - {%- set tmp_column = column_name + "__dbt_alter" -%} - {% set alter_column_type %} - alter {{ relation.type }} {{ relation }} add "{{ tmp_column }}" {{ new_column_type }}; - {%- endset %} + {% set prefer_single = config.get('prefer_single_alter_column', false) %} - {% set update_column %} - update {{ relation }} set "{{ tmp_column }}" = "{{ column_name }}"; - {%- endset %} + {% if prefer_single and relation.type == 'table' %} + {% set alter_sql %} + alter {{ relation.type }} {{ relation }} + alter column "{{ column_name }}" {{ new_column_type }}; + {%- endset %} + {% do run_query(alter_sql) %} - {% set drop_column %} - alter {{ relation.type }} {{ relation }} drop column "{{ column_name }}"; - {%- endset %} + {% else %} + {%- set tmp_column = column_name + "__dbt_alter" -%} - {% set rename_column %} - exec sp_rename '{{ relation | replace('"', '') }}.{{ tmp_column }}', '{{ column_name }}', 'column' - {%- endset %} + {% set add_column %} + alter {{ relation.type }} {{ relation }} + add "{{ tmp_column }}" {{ new_column_type }}; + {%- endset %} + {% set update_column %} + update {{ relation }} set "{{ tmp_column }}" = "{{ column_name }}"; + {%- endset %} + {% set drop_column %} + alter {{ relation.type }} {{ relation }} + drop column "{{ column_name }}"; + {%- endset %} + {% set rename_column %} + exec sp_rename '{{ relation | replace('"', '') }}.{{ tmp_column }}', '{{ column_name }}', 'column' + {%- endset %} - {% do run_query(alter_column_type) %} - {% do run_query(update_column) %} - {% do run_query(drop_column) %} - {% do run_query(rename_column) %} + {% do run_query(add_column) %} + {% do run_query(update_column) %} + {% do run_query(drop_column) %} + {% do run_query(rename_column) %} + {% endif %} {% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/incremental/incremental.sql b/dbt/include/sqlserver/macros/materializations/models/incremental/incremental.sql index a70a0f658..ecd016e5c 100644 --- a/dbt/include/sqlserver/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/sqlserver/macros/materializations/models/incremental/incremental.sql @@ -42,9 +42,11 @@ {% set contract_config = config.get('contract') %} {% if not contract_config or not contract_config.enforced %} + {% set expansion_max_rows = config.get('column_type_expansion_max_rows', 1000000) %} {% do adapter.expand_target_column_types( from_relation=temp_relation, - to_relation=target_relation) %} + to_relation=target_relation, + max_rows=expansion_max_rows) %} {% endif %} {#-- Process schema changes. Returns dict of changes if successful. Use source columns for upserting/merging --#} {% set dest_columns = process_schema_changes(on_schema_change, temp_relation, existing_relation) %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql index 2f955d596..3401ec9a4 100644 --- a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql +++ b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql @@ -65,8 +65,10 @@ {% set build_or_select_sql = snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} {% set staging_table = build_snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} -- this may no-op if the database does not require column expansion + {% set expansion_max_rows = config.get('column_type_expansion_max_rows', 1000000) %} {% do adapter.expand_target_column_types(from_relation=staging_table, - to_relation=target_relation) %} + to_relation=target_relation, + max_rows=expansion_max_rows) %} {% set remove_columns = ['dbt_change_type', 'DBT_CHANGE_TYPE', 'dbt_unique_key', 'DBT_UNIQUE_KEY'] %} {% if unique_key | is_list %} diff --git a/dbt/include/sqlserver/macros/relations/seeds/helpers.sql b/dbt/include/sqlserver/macros/relations/seeds/helpers.sql index 34c8e726d..46b59f0a1 100644 --- a/dbt/include/sqlserver/macros/relations/seeds/helpers.sql +++ b/dbt/include/sqlserver/macros/relations/seeds/helpers.sql @@ -22,27 +22,30 @@ {% macro sqlserver__load_csv_rows(model, agate_table) %} {% set cols_sql = get_seed_column_quoted_csv(model, agate_table.column_names) %} {% set batch_size = calc_batch_size(agate_table.column_names|length) %} - {% set bindings = [] %} {% set statements = [] %} {{ log("Inserting batches of " ~ batch_size ~ " records") }} {% for chunk in agate_table.rows | batch(batch_size) %} {% set bindings = [] %} + {% set values_clause = [] %} {% for row in chunk %} - {% do bindings.extend(row) %} + {% set row_values = [] %} + {% for column in agate_table.column_names %} + {%- set val = row[loop.index0] -%} + {%- if val is none -%} + {%- do row_values.append("null") -%} + {%- else -%} + {%- do row_values.append(get_binding_char()) -%} + {%- do bindings.append(val) -%} + {%- endif -%} + {% endfor %} + {% do values_clause.append("(" ~ row_values | join(", ") ~ ")") %} {% endfor %} {% set sql %} - insert into {{ this.render() }} ({{ cols_sql }}) values - {% for row in chunk -%} - ({%- for column in agate_table.column_names -%} - {{ get_binding_char() }} - {%- if not loop.last%},{%- endif %} - {%- endfor -%}) - {%- if not loop.last%},{%- endif %} - {%- endfor %} + insert into {{ this.render() }} ({{ cols_sql }}) values {{ values_clause | join(", ") }} {% endset %} {% do adapter.add_query(sql, bindings=bindings, abridge_sql_log=True) %} diff --git a/pyproject.toml b/pyproject.toml index f614941a7..e2130542d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,3 +90,6 @@ namespaces = true [tool.uv] link-mode = "copy" + +[tool.ty.environment] +python = ".venv" diff --git a/tests/functional/adapter/dbt/test_catalog.py b/tests/functional/adapter/dbt/test_catalog.py index da0924334..3ff431d8b 100644 --- a/tests/functional/adapter/dbt/test_catalog.py +++ b/tests/functional/adapter/dbt/test_catalog.py @@ -154,3 +154,49 @@ def test_docs_generate_includes_non_default_database(self, project): assert "id" in other_node["columns"] finally: self.cleanup_secondary_database(project) + + +CATALOG_COLUMN_TYPES_SQL = """ +{{ config(materialized='table') }} +select + cast('hello' as nvarchar(50)) as nv_col, + cast('h' as nchar(1)) as nc_col, + cast(1 as int) as int_col +""" + + +class TestCatalogColumnTypes: + """ + This test addresses: https://github.com/dbt-msft/dbt-sqlserver/issues/637 + catalog.sql used system_type_id instead of user_type_id causing + NVARCHAR/NCHAR columns to appear as SYSNAME in dbt docs. + """ + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "catalog_column_types_test"} + + @pytest.fixture(scope="class") + def models(self): + return {"catalog_model.sql": CATALOG_COLUMN_TYPES_SQL} + + @pytest.fixture(scope="class") + def docs(self, project): + run_dbt(["run"]) + yield run_dbt(["docs", "generate"]) + + def test_catalog_does_not_return_sysname(self, project, docs): + catalog_path = os.path.join(project.project_root, "target", "catalog.json") + with open(catalog_path) as f: + catalog = json.load(f) + + nodes = catalog.get("nodes", {}) + for node_name, node in nodes.items(): + if "catalog_model" not in node_name: + continue + for col_name, col in node.get("columns", {}).items(): + col_type = col.get("type", "").lower() + assert "sysname" not in col_type, ( + f"Column '{col_name}' has type '{col_type}' " + f"which contains SYSNAME instead of NVARCHAR/NCHAR" + ) diff --git a/tests/functional/adapter/mssql/test_column_type_expansion.py b/tests/functional/adapter/mssql/test_column_type_expansion.py new file mode 100644 index 000000000..745acc89c --- /dev/null +++ b/tests/functional/adapter/mssql/test_column_type_expansion.py @@ -0,0 +1,274 @@ +"""Functional tests for column type expansion and addition +via the incremental materialization. + +Two scenarios tested with default and native string type flags: + 1. Column type expansion via expand_target_column_types + 2. Adding a new nvarchar column via on_schema_change (append / sync_all) +""" + +import os + +import pytest + +from dbt.tests.util import run_dbt + + +def _column_type(project, schema, table, column): + rows = project.run_sql( + f""" + select t.name, c.max_length + from [{project.database}].sys.columns c + inner join [{project.database}].sys.types t + on c.user_type_id = t.user_type_id + where c.object_id = object_id('[{project.database}].[{schema}].[{table}]') + and c.name = '{column}' + """, + fetch="all", + ) + if not rows: + return None + dtype, max_length = rows[0] + if dtype in ("nchar", "nvarchar", "sysname") and max_length != -1: + return (dtype, max_length // 2) + return (dtype, max_length) + + +def write_model(project, filename, contents): + path = os.path.join(project.project_root, "models", filename) + with open(path, "w") as f: + f.write(contents) + + +# --- Model SQL for expansion test --- + +EXPAND_V1 = """ +{{ config(materialized='incremental', unique_key='id') }} +select 1 as id, cast('hello' as varchar(10)) as str_col +""" + +EXPAND_V2 = """ +{{ config(materialized='incremental', unique_key='id') }} +select 1 as id, cast('hello world' as varchar(25)) as str_col +""" + +# --- Model SQL for add-column test --- + +ADD_COL_V1 = """ +{{ + config(materialized='incremental', unique_key='id', + on_schema_change='append_new_columns') +}} +select 1 as id, cast('hello' as varchar(10)) as str_col +""" + +ADD_COL_V2 = """ +{{ + config(materialized='incremental', unique_key='id', + on_schema_change='append_new_columns') +}} +select 1 as id, + cast('hello' as varchar(10)) as str_col, + cast('hello' as nvarchar(20)) as new_col +""" + +# --- Model SQL for sync-all-columns test --- + +SYNC_V1 = """ +{{ + config(materialized='incremental', unique_key='id', + on_schema_change='sync_all_columns') +}} +select 1 as id, cast('hello' as varchar(10)) as str_col +""" + +SYNC_V2 = """ +{{ + config(materialized='incremental', unique_key='id', + on_schema_change='sync_all_columns') +}} +select 1 as id, + cast('hello world' as varchar(25)) as str_col, + cast('hello' as nvarchar(20)) as new_col +""" + + +# ============================================================================ +# Default string types (dbt_sqlserver_use_native_string_types = false) +# ============================================================================ + + +class TestExpansionDefault: + @pytest.fixture(scope="class") + def models(self): + return {"expand_test.sql": EXPAND_V1} + + def test_varchar_size_expansion(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "expand_test.sql", EXPAND_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "expand_test", "str_col") + assert typ == ("varchar", 25), f"Expected varchar(25), got {typ}" + + +class TestAddColumnDefault: + """ + This test addresses: https://github.com/dbt-msft/dbt-sqlserver/issues/446 + """ + + @pytest.fixture(scope="class") + def models(self): + return {"add_col_test.sql": ADD_COL_V1} + + def test_add_nvarchar_column(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "add_col_test.sql", ADD_COL_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "add_col_test", "new_col") + assert typ == ("nvarchar", 20), f"Expected nvarchar(20), got {typ}" + + +class TestSyncColumnsDefault: + @pytest.fixture(scope="class") + def models(self): + return {"sync_test.sql": SYNC_V1} + + def test_sync_all_columns(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "sync_test.sql", SYNC_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "sync_test", "str_col") + assert typ == ("varchar", 25), f"Expected varchar(25), got {typ}" + + typ = _column_type(project, project.test_schema, "sync_test", "new_col") + assert typ == ("nvarchar", 20), f"Expected nvarchar(20), got {typ}" + + +# ============================================================================ +# Safe type expansion: varchar -> nvarchar (requires enable_safe flag) +# ============================================================================ + + +NVARCHAR_V1 = """ +{{ config(materialized='incremental', unique_key='id', + column_type_expansion_max_rows=10) }} +select 1 as id, cast('hello' as varchar(10)) as str_col +""" + +NVARCHAR_V2 = """ +{{ config(materialized='incremental', unique_key='id', + column_type_expansion_max_rows=10) }} +select 1 as id, cast('hi' as nvarchar(25)) as str_col +""" + + +class TestVarcharToNvarcharWithoutFlag: + @pytest.fixture(scope="class") + def models(self): + return {"nvarchar_test.sql": NVARCHAR_V1} + + def test_varchar_to_nvarchar_blocked_without_flag(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "nvarchar_test.sql", NVARCHAR_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "nvarchar_test", "str_col") + assert typ == ("varchar", 10), f"Expected varchar(10), got {typ}" + + +class TestVarcharToNvarcharWithFlag: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_enable_safe_type_expansion": True}} + + @pytest.fixture(scope="class") + def models(self): + return {"nvarchar_safe_test.sql": NVARCHAR_V1} + + def test_varchar_to_nvarchar_works_with_flag(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "nvarchar_safe_test.sql", NVARCHAR_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "nvarchar_safe_test", "str_col") + assert typ == ("nvarchar", 25), f"Expected nvarchar(25), got {typ}" + + +# ============================================================================ +# Native string types (dbt_sqlserver_use_native_string_types = true) +# ============================================================================ + + +class TestExpansionNative: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_native_string_types": True}} + + @pytest.fixture(scope="class") + def models(self): + return {"expand_test.sql": EXPAND_V1} + + def test_varchar_size_expansion_native(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "expand_test.sql", EXPAND_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "expand_test", "str_col") + assert typ == ("varchar", 25), f"Expected varchar(25), got {typ}" + + +class TestAddColumnNative: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_native_string_types": True}} + + @pytest.fixture(scope="class") + def models(self): + return {"add_col_test.sql": ADD_COL_V1} + + def test_add_nvarchar_column_native(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "add_col_test.sql", ADD_COL_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "add_col_test", "new_col") + assert typ == ("nvarchar", 20), f"Expected nvarchar(20), got {typ}" + + +class TestSyncColumnsNative: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_native_string_types": True}} + + @pytest.fixture(scope="class") + def models(self): + return {"sync_test.sql": SYNC_V1} + + def test_sync_all_columns_native(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "sync_test.sql", SYNC_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "sync_test", "str_col") + assert typ == ("varchar", 25), f"Expected varchar(25), got {typ}" + + typ = _column_type(project, project.test_schema, "sync_test", "new_col") + assert typ == ("nvarchar", 20), f"Expected nvarchar(20), got {typ}" diff --git a/tests/functional/adapter/mssql/test_mssql_seed.py b/tests/functional/adapter/mssql/test_mssql_seed.py index 8a8cdfe22..a2b21222a 100644 --- a/tests/functional/adapter/mssql/test_mssql_seed.py +++ b/tests/functional/adapter/mssql/test_mssql_seed.py @@ -40,3 +40,41 @@ def seeds(self): def test_large_seed(self, project): run_dbt(["seed"]) + + +seed_empty_numeric_csv = """x +123 + +456 +""" + +seed_empty_numeric_yml = """ +version: 2 +seeds: + - name: seed_empty_numeric + config: + column_types: + x: numeric(18, 0) +""" + + +class TestSeedNumericColumnWithEmptyRows: + """ + This test addresses: https://github.com/dbt-msft/dbt-sqlserver/issues/425 + """ + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "seed_empty_numeric_test"} + + @pytest.fixture(scope="class") + def seeds(self): + return { + "seed_empty_numeric.csv": seed_empty_numeric_csv, + "schema.yml": seed_empty_numeric_yml, + } + + def test_seed_numeric_column_with_empty_rows(self, project): + results = run_dbt(["seed"]) + assert len(results) == 1 + assert results[0].status == "success" diff --git a/tests/unit/adapters/mssql/test_can_expand_to.py b/tests/unit/adapters/mssql/test_can_expand_to.py new file mode 100644 index 000000000..bc63a09ec --- /dev/null +++ b/tests/unit/adapters/mssql/test_can_expand_to.py @@ -0,0 +1,85 @@ +import pytest + +from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn + + +def col_kwargs(dtype, char_size=None, numeric_precision=0, numeric_scale=0): + return { + "column": "c", + "dtype": dtype, + "char_size": char_size, + "numeric_precision": numeric_precision, + "numeric_scale": numeric_scale, + } + + +@pytest.mark.parametrize( + "src_kwargs,tgt_kwargs,expect_with_flag,expect_without_flag", + [ + # String same-family expansions always work + (col_kwargs("varchar", char_size=10), col_kwargs("varchar", char_size=100), True, True), + (col_kwargs("char", char_size=5), col_kwargs("char", char_size=20), True, True), + (col_kwargs("nvarchar", char_size=50), col_kwargs("nvarchar", char_size=200), True, True), + (col_kwargs("nchar", char_size=10), col_kwargs("nchar", char_size=30), True, True), + # String same-size does not expand + (col_kwargs("varchar", char_size=100), col_kwargs("varchar", char_size=100), False, False), + # String smaller target does not expand + (col_kwargs("varchar", char_size=100), col_kwargs("varchar", char_size=50), False, False), + # String cross-family (VARCHAR -> NVARCHAR) requires flag + (col_kwargs("varchar", char_size=10), col_kwargs("nvarchar", char_size=10), True, False), + (col_kwargs("char", char_size=5), col_kwargs("nchar", char_size=5), True, False), + # String cross-family reverse (NVARCHAR -> VARCHAR) never works + (col_kwargs("nvarchar", char_size=10), col_kwargs("varchar", char_size=10), False, False), + # Integer family promotions require the feature flag + (col_kwargs("int"), col_kwargs("bigint"), True, False), + (col_kwargs("bit"), col_kwargs("tinyint"), True, False), + # Integer -> numeric widening requires the feature flag + (col_kwargs("int"), col_kwargs("numeric", numeric_precision=10), True, False), + # Numeric/decimal promotions: precision/scale must increase; flag required + ( + col_kwargs("numeric", numeric_precision=10, numeric_scale=2), + col_kwargs("numeric", numeric_precision=12, numeric_scale=4), + True, + False, + ), + ( + col_kwargs("numeric", numeric_precision=10, numeric_scale=2), + col_kwargs("numeric", numeric_precision=12, numeric_scale=1), + False, + False, + ), + # Fixed-money types (MONEY/SMALLMONEY) + ( + col_kwargs("smallmoney", numeric_precision=10, numeric_scale=4), + col_kwargs("money", numeric_precision=19, numeric_scale=4), + True, + False, + ), + ( + col_kwargs("money", numeric_precision=19, numeric_scale=4), + col_kwargs("numeric", numeric_precision=20, numeric_scale=4), + True, + False, + ), + # MONEY -> NUMERIC with dtype change and equal specs + ( + col_kwargs("money", numeric_precision=19, numeric_scale=4), + col_kwargs("numeric", numeric_precision=19, numeric_scale=4), + True, + False, + ), + # NUMERIC -> MONEY that would shrink precision should not be allowed + ( + col_kwargs("numeric", numeric_precision=20, numeric_scale=4), + col_kwargs("money", numeric_precision=19, numeric_scale=4), + False, + False, + ), + ], +) +def test_can_expand_parametrized(src_kwargs, tgt_kwargs, expect_with_flag, expect_without_flag): + src = SQLServerColumn(**src_kwargs) + tgt = SQLServerColumn(**tgt_kwargs) + + assert src.can_expand_to(tgt) is expect_without_flag + assert (src.can_expand_to(tgt) or src.can_expand_safe(tgt)) is expect_with_flag diff --git a/tests/unit/adapters/mssql/test_expand_column_types.py b/tests/unit/adapters/mssql/test_expand_column_types.py new file mode 100644 index 000000000..7f0edd230 --- /dev/null +++ b/tests/unit/adapters/mssql/test_expand_column_types.py @@ -0,0 +1,94 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from dbt.adapters.sqlserver.sqlserver_adapter import SQLServerAdapter +from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation + + +@pytest.fixture +def adapter(): + config = MagicMock() + config.flags = {} + config.project_name = "test" + config.credentials.type = "sqlserver" + mp_context = MagicMock() + adapter = SQLServerAdapter(config, mp_context) + adapter._get_row_count = MagicMock(return_value=0) + adapter.get_columns_in_relation = MagicMock(return_value=[]) + adapter.alter_column_type = MagicMock() + adapter.behavior = MagicMock() + adapter.behavior.dbt_sqlserver_enable_safe_type_expansion = True + return adapter + + +def make_rel(name="t"): + rel = MagicMock(spec=SQLServerRelation) + rel.__str__ = lambda s: f"test_schema.{name}" + return rel + + +class TestExpandColumnTypes: + def test_skips_row_count_when_max_rows_is_negative_one(self, adapter): + adapter.expand_column_types(make_rel("goal"), make_rel("current"), max_rows=-1) + adapter._get_row_count.assert_not_called() + + def test_blocks_safe_expansion_when_max_rows_is_zero(self, adapter): + adapter._get_row_count.return_value = 0 + adapter.get_columns_in_relation = MagicMock(return_value=[]) + adapter.alter_column_type = MagicMock() + + goal = make_rel("goal") + goal_col = MagicMock() + goal_col.name = "c" + goal_col.dtype = "nvarchar" + goal_col.is_string = MagicMock(return_value=True) + goal_col.is_number = MagicMock(return_value=True) + goal_col.string_size = MagicMock(return_value=20) + goal_col.string_type_instance = MagicMock(return_value="nvarchar(20)") + goal_col.data_type = "nvarchar(20)" + + current = make_rel("current") + current_col = MagicMock() + current_col.name = "c" + current_col.dtype = "varchar" + current_col.is_string = MagicMock(return_value=True) + current_col.is_number = MagicMock(return_value=True) + current_col.can_expand_to = MagicMock(return_value=False) + current_col.can_expand_safe = MagicMock(return_value=True) + + adapter.get_columns_in_relation.side_effect = lambda r: ( + [goal_col] if r is goal else [current_col] + ) + + with patch("dbt.adapters.sqlserver.sqlserver_adapter.logger"): + adapter.expand_column_types(goal, current, max_rows=0) + + adapter._get_row_count.assert_not_called() + adapter.alter_column_type.assert_not_called() + + def test_reads_row_count_when_within_limit(self, adapter): + adapter._get_row_count.return_value = 50 + adapter.expand_column_types(make_rel("goal"), make_rel("current"), max_rows=100) + adapter._get_row_count.assert_called_once() + + def test_emits_warning_when_row_count_exceeds_max(self, adapter): + adapter._get_row_count.return_value = 200 + with patch("dbt.adapters.sqlserver.sqlserver_adapter.logger") as logger: + adapter.expand_column_types(make_rel("goal"), make_rel("current"), max_rows=100) + adapter._get_row_count.assert_called_once() + logger.warning.assert_called_once() + + def test_expand_target_column_types_forwards_max_rows(self, adapter): + adapter._get_row_count.return_value = 0 + adapter.get_columns_in_relation = MagicMock(return_value=[]) + adapter.alter_column_type = MagicMock() + + goal = make_rel("goal") + current = make_rel("current") + max_rows = 500 + + with patch.object(adapter, "expand_column_types") as mock_expand: + adapter.expand_target_column_types(goal, current, max_rows=max_rows) + + mock_expand.assert_called_once_with(goal, current, max_rows) diff --git a/tests/unit/adapters/mssql/test_sqlserver_column.py b/tests/unit/adapters/mssql/test_sqlserver_column.py new file mode 100644 index 000000000..e5bd0956f --- /dev/null +++ b/tests/unit/adapters/mssql/test_sqlserver_column.py @@ -0,0 +1,123 @@ +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn + + +class TestSQLServerColumnIsString: + def test_varchar_is_string(self): + col = SQLServerColumn("c", "varchar", char_size=50) + assert col.is_string() is True + + def test_char_is_string(self): + col = SQLServerColumn("c", "char", char_size=10) + assert col.is_string() is True + + def test_nvarchar_is_string(self): + col = SQLServerColumn("c", "nvarchar", char_size=100) + assert col.is_string() is True + + def test_nchar_is_string(self): + col = SQLServerColumn("c", "nchar", char_size=20) + assert col.is_string() is True + + def test_int_is_not_string(self): + col = SQLServerColumn("c", "int") + assert col.is_string() is False + + def test_numeric_is_not_string(self): + col = SQLServerColumn("c", "numeric") + assert col.is_string() is False + + +class TestSQLServerColumnStringTypeInstance: + def test_varchar_default(self): + col = SQLServerColumn("c", "varchar") + result = col.string_type_instance(100) + assert result == "varchar(100)" + + def test_varchar_max(self): + col = SQLServerColumn("c", "varchar") + result = col.string_type_instance(0) + assert result == "varchar(8000)" + + def test_nvarchar(self): + col = SQLServerColumn("c", "nvarchar") + result = col.string_type_instance(200) + assert result == "nvarchar(200)" + + def test_nvarchar_max(self): + col = SQLServerColumn("c", "nvarchar") + result = col.string_type_instance(0) + assert result == "nvarchar(4000)" + + def test_nchar(self): + col = SQLServerColumn("c", "nchar") + result = col.string_type_instance(50) + assert result == "nchar(50)" + + def test_nchar_max(self): + col = SQLServerColumn("c", "nchar") + result = col.string_type_instance(0) + assert result == "nchar(1)" + + def test_char_default(self): + col = SQLServerColumn("c", "char") + result = col.string_type_instance(5) + assert result == "char(5)" + + result = col.string_type_instance(0) + assert result == "char(1)" + + +class TestSQLServerColumnDataType: + def test_varchar_data_type(self): + col = SQLServerColumn("c", "varchar", char_size=100) + assert col.data_type == "varchar(100)" + + def test_nvarchar_data_type(self): + col = SQLServerColumn("c", "nvarchar", char_size=200) + assert col.data_type == "nvarchar(200)" + + +class TestSQLServerColumnIsFixedNumeric: + def test_money(self): + col = SQLServerColumn("c", "money") + assert col.is_fixed_numeric() is True + + def test_smallmoney(self): + col = SQLServerColumn("c", "smallmoney") + assert col.is_fixed_numeric() is True + + def test_numeric_is_not_fixed(self): + col = SQLServerColumn("c", "numeric") + assert col.is_fixed_numeric() is False + + +class TestSQLServerColumnIsNumeric: + def test_numeric(self): + col = SQLServerColumn("c", "numeric") + assert col.is_numeric() is True + + def test_decimal(self): + col = SQLServerColumn("c", "decimal") + assert col.is_numeric() is True + + def test_money_is_not_numeric(self): + col = SQLServerColumn("c", "money") + assert col.is_numeric() is False + + +class TestSQLServerColumnStringSize: + def test_string_size_with_char_size(self): + col = SQLServerColumn("c", "varchar", char_size=100) + assert col.string_size() == 100 + + def test_string_size_none_char_size(self): + col = SQLServerColumn("c", "varchar") + assert col.string_size() == 8000 + + def test_string_size_raises_on_non_string(self): + col = SQLServerColumn("c", "int") + with pytest.raises(DbtRuntimeError, match="Called string_size"): + col.string_size()