Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .circleci/continue_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,11 @@ jobs:
- halt_unless_core
- checkout
- run:
name: Run the migration test
command: ./.circleci/test_migration.sh
name: Run the migration test - sushi
command: ./.circleci/test_migration.sh sushi "--gateway duckdb_persistent"
- run:
name: Run the migration test - sushi_dbt
command: ./.circleci/test_migration.sh sushi_dbt "--config migration_test_config"

ui_style:
docker:
Expand Down
41 changes: 28 additions & 13 deletions .circleci/test_migration.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
#!/usr/bin/env bash
set -ex

GATEWAY_NAME="duckdb_persistent"
TMP_DIR=$(mktemp -d)
SUSHI_DIR="$TMP_DIR/sushi"


if [[ -z $(git tag --points-at HEAD) ]]; then
# If the current commit is not tagged, we need to find the last tag
LAST_TAG=$(git describe --tags --abbrev=0)
Expand All @@ -14,28 +9,48 @@ else
LAST_TAG=$(git tag --sort=-creatordate | head -n 2 | tail -n 1)
fi

if [ "$1" == "" ]; then
echo "Usage: $0 <example name> <sqlmesh opts>"
echo "eg $0 sushi '--gateway duckdb_persistent'"
exit 1
fi


TMP_DIR=$(mktemp -d)
EXAMPLE_NAME="$1"
SQLMESH_OPTS="$2"
EXAMPLE_DIR="./examples/$EXAMPLE_NAME"
TEST_DIR="$TMP_DIR/$EXAMPLE_NAME"

echo "Running migration test for '$EXAMPLE_NAME' in '$TEST_DIR' for example project '$EXAMPLE_DIR' using options '$SQLMESH_OPTS'"

git checkout $LAST_TAG

# Install dependencies from the previous release.
make install-dev

cp -r ./examples/sushi $TMP_DIR
cp -r $EXAMPLE_DIR $TEST_DIR

# this is only needed temporarily until the released tag for $LAST_TAG includes this config
if [ "$EXAMPLE_NAME" == "sushi_dbt" ]; then
echo 'migration_test_config = sqlmesh_config(Path(__file__).parent, dbt_target_name="duckdb")' >> $TEST_DIR/config.py
fi

# Run initial plan
pushd $SUSHI_DIR
pushd $TEST_DIR
rm -rf ./data/*
sqlmesh --gateway $GATEWAY_NAME plan --no-prompts --auto-apply
sqlmesh $SQLMESH_OPTS plan --no-prompts --auto-apply
rm -rf .cache
popd

# Switch back to the starting state of the repository
# Switch back to the starting state of the repository
git checkout -

# Install updated dependencies.
make install-dev

# Migrate and make sure the diff is empty
pushd $SUSHI_DIR
sqlmesh --gateway $GATEWAY_NAME migrate
sqlmesh --gateway $GATEWAY_NAME diff prod
popd
pushd $TEST_DIR
sqlmesh $SQLMESH_OPTS migrate
sqlmesh $SQLMESH_OPTS diff prod
popd
2 changes: 2 additions & 0 deletions examples/sushi_dbt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
config = sqlmesh_config(Path(__file__).parent)

test_config = config

migration_test_config = sqlmesh_config(Path(__file__).parent, dbt_target_name="duckdb")
2 changes: 1 addition & 1 deletion sqlmesh/cli/project_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def init_example_project(
dlt_path: t.Optional[str] = None,
schema_name: str = "sqlmesh_example",
cli_mode: InitCliMode = InitCliMode.DEFAULT,
start: t.Optional[str] = None,
) -> Path:
root_path = Path(path)

Expand Down Expand Up @@ -336,7 +337,6 @@ def init_example_project(

models: t.Set[t.Tuple[str, str]] = set()
settings = None
start = None
if engine_type and template == ProjectTemplate.DLT:
project_dialect = dialect or DIALECT_TO_TYPE.get(engine_type)
if pipeline and project_dialect:
Expand Down
10 changes: 8 additions & 2 deletions sqlmesh/core/audit/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
sorted_python_env_payloads,
)
from sqlmesh.core.model.common import make_python_env, single_value_or_tuple, ParsableSql
from sqlmesh.core.node import _Node
from sqlmesh.core.node import _Node, DbtInfoMixin, DbtNodeInfo
from sqlmesh.core.renderer import QueryRenderer
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.errors import AuditConfigError, SQLMeshError, raise_config_error
Expand Down Expand Up @@ -120,7 +120,7 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A
return {}


class ModelAudit(PydanticModel, AuditMixin, frozen=True):
class ModelAudit(PydanticModel, AuditMixin, DbtInfoMixin, frozen=True):
"""
Audit is an assertion made about your tables.
Expand All @@ -137,6 +137,7 @@ class ModelAudit(PydanticModel, AuditMixin, frozen=True):
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
formatting: t.Optional[bool] = Field(default=None, exclude=True)
dbt_node_info_: t.Optional[DbtNodeInfo] = Field(alias="dbt_node_info", default=None)

_path: t.Optional[Path] = None

Expand All @@ -150,6 +151,10 @@ def __str__(self) -> str:
path = f": {self._path.name}" if self._path else ""
return f"{self.__class__.__name__}<{self.name}{path}>"

@property
def dbt_node_info(self) -> t.Optional[DbtNodeInfo]:
return self.dbt_node_info_


class StandaloneAudit(_Node, AuditMixin):
"""
Expand Down Expand Up @@ -552,4 +557,5 @@ def _maybe_parse_arg_pair(e: exp.Expression) -> t.Tuple[str, exp.Expression]:
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)),
"tags": single_value_or_tuple,
"default_catalog": exp.to_identifier,
"dbt_node_info_": lambda value: value.to_expression(),
}
4 changes: 2 additions & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,9 +1697,9 @@ def plan_builder(
console=self.console,
user_provided_flags=user_provided_flags,
selected_models={
dbt_name
dbt_unique_id
for model in model_selector.expand_model_selections(select_models or "*")
if (dbt_name := snapshots[model].node.dbt_name)
if (dbt_unique_id := snapshots[model].node.dbt_unique_id)
},
explain=explain or False,
ignore_cron=ignore_cron or False,
Expand Down
4 changes: 4 additions & 0 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,9 @@ def metadata_hash(self) -> str:
for k, v in sorted(args.items()):
metadata.append(f"{k}:{gen(v)}")

if self.dbt_node_info:
metadata.append(self.dbt_node_info.json(sort_keys=True))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to extend the ./.circleci/test_migration.sh test to also run on sushi_dbt project and not just sushi, since we now have dbt-specific logic when hashing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added this and I think it's failing in the way you expected it to:

**Metadata Updated:**
- `sushi.customer_revenue_by_day`
- `sushi.customers`
- `sushi.top_waiters`
- `sushi.waiter_as_customer_by_day`
- `sushi.waiter_revenue_by_day_v1`
- `sushi.waiter_revenue_by_day_v2`
- `sushi.waiters`
- `sushi_raw.items`
- `sushi_raw.order_items`
- `sushi_raw.orders`
- `sushi_raw.waiter_names`

Exited with code exit status 1

This is correct because after the migration in this PR runs and a new plan is made, the extra fields get read from disk and show as a metadata change because they go from being empty in state to being populated in state


metadata.extend(self._additional_metadata)

self._metadata_hash = hash_data(metadata)
Expand Down Expand Up @@ -3019,6 +3022,7 @@ def render_expression(
"formatting": str,
"optimize_query": str,
"virtual_environment_mode": lambda value: exp.Literal.string(value.value),
"dbt_node_info_": lambda value: value.to_expression(),
}


Expand Down
103 changes: 101 additions & 2 deletions sqlmesh/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,101 @@ def milliseconds(self) -> int:
return self.seconds * 1000


class DbtNodeInfo(PydanticModel):
    """
    Represents dbt-specific model information set by the dbt loader and intended to be made available at the Snapshot level
    (as opposed to hidden within the individual model jinja macro registries).
    This allows for things like injecting implementations of variables / functions into the Jinja context that are compatible with
    their dbt equivalents but are backed by the sqlmesh snapshots in any given plan / environment
    """

    unique_id: str
    """This is the node/resource name/unique_id that's used as the node key in the dbt manifest.
    It's prefixed by the resource type and is exposed in context variables like {{ selected_resources }}.
    Examples:
        - test.jaffle_shop.unique_stg_orders_order_id.e3b841c71a
        - seed.jaffle_shop.raw_payments
        - model.jaffle_shop.stg_orders
    """

    name: str
    """Name of this object in the dbt global namespace, used by things like {{ ref() }} calls.
    Examples:
        - unique_stg_orders_order_id
        - raw_payments
        - stg_orders
    """

    fqn: str
    """Used for selectors in --select/--exclude.
    Takes the filesystem into account so may be structured differently from :unique_id.
    Examples:
        - jaffle_shop.staging.unique_stg_orders_order_id
        - jaffle_shop.raw_payments
        - jaffle_shop.staging.stg_orders
    """

    alias: t.Optional[str] = None
    """This is dbt's way of overriding the _physical table_ a model is written to.
    It's used in the following situation:
        - Say you have two models, "stg_customers" and "customers"
        - You want "stg_customers" to be written to the "staging" schema as eg "staging.customers" - NOT "staging.stg_customers"
        - But you can't rename the file to "customers" because it will conflict with your other model file "customers"
        - Even if you put it in a different folder, eg "staging/customers.sql" - dbt still has a global namespace so it will conflict
          when you try to do something like "{{ ref('customers') }}"
        - So dbt's solution to this problem is to keep calling it "stg_customers" at the dbt project/model level,
          but allow overriding the physical table to "customers" via something like "{{ config(alias='customers', schema='staging') }}"
    Note that if :alias is set, it does *not* replace :name at the model level and cannot be used interchangeably with :name.
    It also does not affect the :fqn or :unique_id. It's just used to override :name when it comes time to generate the physical table name.
    """

    @model_validator(mode="after")
    def post_init(self) -> Self:
        # by default, dbt sets alias to the same as :name
        # however, we only want to include :alias if it is actually different / actually providing an override
        if self.alias == self.name:
            self.alias = None
        return self

    def to_expression(self) -> exp.Expression:
        """Produce a SQLGlot expression representing this object, for use in things like the model/audit definition renderers"""
        # sorted() keys + exclude_none=True make the rendered tuple deterministic and omit
        # fields (like :alias) that were normalized away, keeping serialized output stable
        return exp.tuple_(
            *(
                exp.PropertyEQ(this=exp.var(k), expression=exp.Literal.string(v))
                for k, v in sorted(self.model_dump(exclude_none=True).items())
            )
        )


class DbtInfoMixin:
    """Encapsulates properties that exist purely for dbt compatibility; native
    (non-dbt) projects do not require them."""

    @property
    def dbt_node_info(self) -> t.Optional[DbtNodeInfo]:
        # Concrete hosts must override this to expose their stored dbt node info.
        raise NotImplementedError()

    @property
    def dbt_unique_id(self) -> t.Optional[str]:
        """Used for compatibility with jinja context variables such as {{ selected_resources }}"""
        info = self.dbt_node_info
        return info.unique_id if info else None

    @property
    def dbt_fqn(self) -> t.Optional[str]:
        """Used in the selector engine for compatibility with selectors that select models by dbt fqn"""
        info = self.dbt_node_info
        return info.fqn if info else None


# this must be sorted in descending order
INTERVAL_SECONDS = {
IntervalUnit.YEAR: 60 * 60 * 24 * 365,
Expand All @@ -165,7 +260,7 @@ def milliseconds(self) -> int:
}


class _Node(PydanticModel):
class _Node(DbtInfoMixin, PydanticModel):
"""
Node is the core abstraction for entity that can be executed within the scheduler.
Expand Down Expand Up @@ -199,7 +294,7 @@ class _Node(PydanticModel):
interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None)
tags: t.List[str] = []
stamp: t.Optional[str] = None
dbt_name: t.Optional[str] = None # dbt node name
dbt_node_info_: t.Optional[DbtNodeInfo] = Field(alias="dbt_node_info", default=None)
_path: t.Optional[Path] = None
_data_hash: t.Optional[str] = None
_metadata_hash: t.Optional[str] = None
Expand Down Expand Up @@ -446,6 +541,10 @@ def is_audit(self) -> bool:
"""Return True if this is an audit node"""
return False

@property
def dbt_node_info(self) -> t.Optional[DbtNodeInfo]:
return self.dbt_node_info_


class NodeType(str, Enum):
MODEL = "model"
Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,9 @@ def _run_or_audit(
run_environment_statements=run_environment_statements,
audit_only=audit_only,
auto_restatement_triggers=auto_restatement_triggers,
selected_models={s.node.dbt_name for s in merged_intervals if s.node.dbt_name},
selected_models={
s.node.dbt_unique_id for s in merged_intervals if s.node.dbt_unique_id
},
)

return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
Expand Down
17 changes: 9 additions & 8 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlmesh.core.config.base import UpdateStrategy
from sqlmesh.core.config.common import VirtualEnvironmentMode
from sqlmesh.core.model import Model
from sqlmesh.core.node import DbtNodeInfo
from sqlmesh.dbt.column import (
ColumnConfig,
column_descriptions_to_sqlmesh,
Expand Down Expand Up @@ -120,8 +121,10 @@ class BaseModelConfig(GeneralConfig):
grain: t.Union[str, t.List[str]] = []

# DBT configuration fields
unique_id: str = ""
name: str = ""
package_name: str = ""
fqn: t.List[str] = []
schema_: str = Field("", alias="schema")
database: t.Optional[str] = None
alias: t.Optional[str] = None
Expand Down Expand Up @@ -273,12 +276,10 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
return {"description", "owner", "stamp", "storage_format"}

@property
def node_name(self) -> str:
resource_type = getattr(self, "resource_type", "model")
node_name = f"{resource_type}.{self.package_name}.{self.name}"
if self.version:
node_name += f".v{self.version}"
return node_name
def node_info(self) -> DbtNodeInfo:
return DbtNodeInfo(
unique_id=self.unique_id, name=self.name, fqn=".".join(self.fqn), alias=self.alias
)

def sqlmesh_model_kwargs(
self,
Expand Down Expand Up @@ -349,8 +350,8 @@ def to_sqlmesh(
def _model_jinja_context(
self, context: DbtContext, dependencies: Dependencies
) -> t.Dict[str, t.Any]:
if context._manifest and self.node_name in context._manifest._manifest.nodes:
attributes = context._manifest._manifest.nodes[self.node_name].to_dict()
if context._manifest and self.unique_id in context._manifest._manifest.nodes:
attributes = context._manifest._manifest.nodes[self.unique_id].to_dict()
if dependencies.model_attrs.all_attrs:
model_node: AttributeDict[str, t.Any] = AttributeDict(attributes)
else:
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/dbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def to_sqlmesh(
extract_dependencies_from_query=False,
allow_partials=allow_partials,
virtual_environment_mode=virtual_environment_mode,
dbt_name=self.node_name,
dbt_node_info=self.node_info,
**optional_kwargs,
**model_kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/dbt/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def to_sqlmesh(
audit_definitions=audit_definitions,
virtual_environment_mode=virtual_environment_mode,
start=self.start or context.sqlmesh_config.model_defaults.start,
dbt_name=self.node_name,
dbt_node_info=self.node_info,
**kwargs,
)

Expand Down
Loading