Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def whoami(
Returns the current authenticated user
"""
user = await User.get_by_username(session, current_user.username)
return UserOutput.from_orm(user)
return UserOutput.model_validate(user, from_attributes=True)


@router.get("/token/")
Expand Down
7 changes: 5 additions & 2 deletions datajunction-server/datajunction_server/api/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ async def list_attributes(
List all available attribute types.
"""
attributes = await AttributeType.get_all(session)
return [AttributeTypeBase.from_orm(attr) for attr in attributes]
return [
AttributeTypeBase.model_validate(attr, from_attributes=True)
for attr in attributes
]


@router.post(
Expand All @@ -62,7 +65,7 @@ async def add_attribute_type(
message=f"Attribute type `{data.name}` already exists!",
)
attribute_type = await AttributeType.create(session, data)
return AttributeTypeBase.from_orm(attribute_type)
return AttributeTypeBase.model_validate(attribute_type, from_attributes=True)


async def default_attribute_types(session: AsyncSession = Depends(get_session)):
Expand Down
6 changes: 3 additions & 3 deletions datajunction-server/datajunction_server/api/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def list_catalogs(
"""
statement = select(Catalog).options(joinedload(Catalog.engines))
return [
CatalogInfo.from_orm(catalog)
CatalogInfo.model_validate(catalog, from_attributes=True)
for catalog in (await session.execute(statement)).unique().scalars()
]

Expand Down Expand Up @@ -111,7 +111,7 @@ async def add_catalog(
await session.commit()
await session.refresh(catalog, ["engines"])

return CatalogInfo.from_orm(catalog)
return CatalogInfo.model_validate(catalog, from_attributes=True)


@router.post(
Expand All @@ -136,7 +136,7 @@ async def add_engines_to_catalog(
session.add(catalog)
await session.commit()
await session.refresh(catalog)
return CatalogInfo.from_orm(catalog)
return CatalogInfo.model_validate(catalog, from_attributes=True)


async def list_new_engines(
Expand Down
2 changes: 1 addition & 1 deletion datajunction-server/datajunction_server/api/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def create_a_collection(
await session.commit()
await session.refresh(collection)

return CollectionInfo.from_orm(collection)
return CollectionInfo.model_validate(collection, from_attributes=True)


@router.delete("/collections/{name}", status_code=HTTPStatus.NO_CONTENT)
Expand Down
2 changes: 1 addition & 1 deletion datajunction-server/datajunction_server/api/cubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ async def get_cube_dimension_values(
value=row[0 : count_column[0]] if count_column else row,
count=row[count_column[0]] if count_column else None,
)
for row in result.results.__root__[0].rows
for row in result.results.root[0].rows
]
return DimensionValues( # pragma: no cover
dimensions=[
Expand Down
26 changes: 17 additions & 9 deletions datajunction-server/datajunction_server/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,15 @@ async def add_availability_state(
table=data.table,
valid_through_ts=data.valid_through_ts,
url=data.url,
min_temporal_partition=data.min_temporal_partition,
max_temporal_partition=data.max_temporal_partition,
min_temporal_partition=[
str(part) for part in data.min_temporal_partition or []
],
max_temporal_partition=[
str(part) for part in data.max_temporal_partition or []
],
partitions=[
partition.dict() if not isinstance(partition, Dict) else partition
for partition in data.partitions # type: ignore
for partition in (data.partitions or [])
],
categorical_partitions=data.categorical_partitions,
temporal_partitions=data.temporal_partitions,
Expand All @@ -159,10 +163,14 @@ async def add_availability_state(
entity_type=EntityType.AVAILABILITY,
node=node.name, # type: ignore
activity_type=ActivityType.CREATE,
pre=AvailabilityStateBase.from_orm(old_availability).dict()
pre=AvailabilityStateBase.model_validate(
old_availability,
).model_dump()
if old_availability
else {},
post=AvailabilityStateBase.from_orm(node_revision.availability).dict(),
post=AvailabilityStateBase.model_validate(
node_revision.availability,
).model_dump(),
user=current_user.username,
),
session=session,
Expand Down Expand Up @@ -262,8 +270,8 @@ async def get_data(
)

# Inject column info if there are results
if result.results.__root__: # pragma: no cover
result.results.__root__[0].columns = generated_sql.columns # type: ignore
if result.results.root: # pragma: no cover
result.results.root[0].columns = generated_sql.columns # type: ignore
return result


Expand Down Expand Up @@ -447,8 +455,8 @@ async def get_data_for_metrics(
)

# Inject column info if there are results
if result.results.__root__: # pragma: no cover
result.results.__root__[0].columns = translated_sql.columns or []
if result.results.root: # pragma: no cover
result.results.root[0].columns = translated_sql.columns or []
return result


Expand Down
4 changes: 2 additions & 2 deletions datajunction-server/datajunction_server/api/djsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ async def get_data_for_djsql(
)

# Inject column info if there are results
if result.results.__root__: # pragma: no cover
result.results.__root__[0].columns = translated_sql.columns or []
if result.results.root: # pragma: no cover
result.results.root[0].columns = translated_sql.columns or []
return result


Expand Down
8 changes: 5 additions & 3 deletions datajunction-server/datajunction_server/api/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def list_engines(
List all available engines
"""
return [
EngineInfo.from_orm(engine)
EngineInfo.model_validate(engine)
for engine in (await session.execute(select(Engine))).scalars()
]

Expand All @@ -58,7 +58,9 @@ async def get_an_engine(
"""
Return an engine by name and version
"""
return EngineInfo.from_orm(await get_engine(session, name, version))
return EngineInfo.model_validate(
await get_engine(session, name, version),
)


@router.post(
Expand Down Expand Up @@ -95,4 +97,4 @@ async def add_engine(
await session.commit()
await session.refresh(engine)

return EngineInfo.from_orm(engine)
return EngineInfo.model_validate(engine)
10 changes: 5 additions & 5 deletions datajunction-server/datajunction_server/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,8 @@ async def query_event_stream(
"query end state detected (%s), sending final event to the client",
query_next.state,
)
if query_next.results.__root__: # pragma: no cover
query_next.results.__root__[0].columns = columns or []
if query_next.results.root: # pragma: no cover
query_next.results.root[0].columns = columns or []
yield {
"event": "message",
"id": uuid.uuid4(),
Expand Down Expand Up @@ -879,13 +879,13 @@ def get_node_revision_materialization(
)
if materialization.strategy != MaterializationStrategy.INCREMENTAL_TIME:
info.urls = [info.urls[0]]
materialization_config_output = MaterializationConfigOutput.from_orm(
materialization_config_output = MaterializationConfigOutput.model_validate(
materialization,
)
materializations.append(
MaterializationConfigInfoUnified(
**materialization_config_output.dict(),
**info.dict(),
**materialization_config_output.model_dump(),
**info.model_dump(),
),
)
return materializations
4 changes: 2 additions & 2 deletions datajunction-server/datajunction_server/api/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def list_history(
offset=offset,
limit=limit,
)
return [HistoryOutput.from_orm(entry) for entry in hist]
return [HistoryOutput.model_validate(entry) for entry in hist]


@router.get("/history/", response_model=List[HistoryOutput])
Expand Down Expand Up @@ -86,4 +86,4 @@ async def list_history_by_node_context(
)
result = await session.execute(statement)
hist = result.scalars().all()
return [HistoryOutput.from_orm(entry) for entry in hist]
return [HistoryOutput.model_validate(entry) for entry in hist]
9 changes: 9 additions & 0 deletions datajunction-server/datajunction_server/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from fastapi.responses import JSONResponse
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_mcp import FastApiMCP
from starlette.middleware.cors import CORSMiddleware

from datajunction_server import __version__
Expand Down Expand Up @@ -172,3 +173,11 @@ async def dj_exception_handler(


app = create_app(lifespan=lifespan)
# Wrap the FastAPI app in an MCP (Model Context Protocol) server so its
# endpoints are discoverable/callable as tools by MCP clients.
# `describe_all_responses` / `describe_full_response_schema` opt into verbose
# tool descriptions (every response code plus the full response schema).
mcp = FastApiMCP(
    app,
    name="DataJunction API MCP",
    description="DJ MCP server",
    describe_all_responses=True,
    describe_full_response_schema=True,
)
# Mount the MCP transport over HTTP on the same app instance.
mcp.mount_http()
32 changes: 25 additions & 7 deletions datajunction-server/datajunction_server/api/materializations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Callable, List
from typing import Any, Callable, List

from fastapi import Depends, Request
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -56,6 +56,19 @@
router = SecureAPIRouter(tags=["materializations"])


def discriminate_materialization(data):
    """
    Discriminate between UpsertMaterialization and UpsertCubeMaterialization
    based on the ``job`` field of a raw request payload.

    Args:
        data: The raw (pre-validation) materialization payload, normally a dict.

    Returns:
        ``UpsertCubeMaterialization`` when the payload is a dict whose ``job``
        key equals ``"druid_cube"``; ``UpsertMaterialization`` in every other
        case (missing key, different job type, or non-dict input).
    """
    # dict.get collapses the original `"job" in data` membership test plus
    # indexing into one lookup; a missing key yields None, which falls through
    # to the default model just like the original code did.
    if isinstance(data, dict) and data.get("job") == "druid_cube":
        return UpsertCubeMaterialization
    return UpsertMaterialization


@router.get(
"/materialization/info",
status_code=200,
Expand Down Expand Up @@ -84,7 +97,7 @@ def materialization_jobs_info() -> JSONResponse:
)
async def upsert_materialization(
node_name: str,
data: UpsertMaterialization | UpsertCubeMaterialization,
data: dict[str, Any],
*,
session: AsyncSession = Depends(get_session),
request: Request,
Expand All @@ -99,6 +112,8 @@ async def upsert_materialization(
Add or update a materialization of the specified node. If a node_name is specified
for the materialization config, it will always update that named config.
"""
materialization_class = discriminate_materialization(data)
materialization = materialization_class.model_validate(data)
request_headers = dict(request.headers)
node = await Node.get_by_name(session, node_name, raise_if_not_exists=True)
if node.type == NodeType.SOURCE: # type: ignore
Expand All @@ -117,20 +132,20 @@ async def upsert_materialization(
current_revision = node.current # type: ignore
old_materializations = {mat.name: mat for mat in current_revision.materializations}

if data.strategy == MaterializationStrategy.INCREMENTAL_TIME:
if materialization.strategy == MaterializationStrategy.INCREMENTAL_TIME: # type: ignore
if not node.current.temporal_partition_columns(): # type: ignore
raise DJInvalidInputException(
http_status_code=HTTPStatus.BAD_REQUEST,
message="Cannot create materialization with strategy "
f"`{data.strategy}` without specifying a time partition column!",
f"`{materialization.strategy}` without specifying a time partition column!", # type: ignore
)

# Create a new materialization
new_materialization = await create_new_materialization(
session,
current_revision,
data,
validate_access,
materialization,
validate_access, # type: ignore
current_user=current_user,
)

Expand Down Expand Up @@ -266,7 +281,10 @@ async def upsert_materialization(
f"Successfully updated materialization config named `{new_materialization.name}` "
f"for node `{node_name}`"
),
"urls": [output.urls for output in materialization_response.values()],
"urls": [
[str(url) for url in output.urls]
for output in materialization_response.values()
],
},
)

Expand Down
2 changes: 1 addition & 1 deletion datajunction-server/datajunction_server/api/namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def hard_delete_node_namespace(
status_code=HTTPStatus.OK,
content={
"message": f"The namespace `{namespace}` has been completely removed.",
"impact": impacts.dict(),
"impact": impacts.model_dump(),
},
)

Expand Down
8 changes: 4 additions & 4 deletions datajunction-server/datajunction_server/api/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
NodeStatusDetails,
NodeValidation,
NodeValidationError,
SourceColumnOutput,
UpdateNode,
)
from datajunction_server.models.node_type import NodeType
Expand Down Expand Up @@ -337,7 +338,7 @@ async def get_node(
options=NodeOutput.load_options(),
raise_if_not_exists=True,
)
return NodeOutput.from_orm(node)
return NodeOutput.model_validate(node)


@router.delete("/nodes/{name}/")
Expand Down Expand Up @@ -606,15 +607,14 @@ async def register_table(
request_headers,
_catalog.engines[0] if len(_catalog.engines) >= 1 else None,
)

return await create_source(
data=CreateSourceNode(
catalog=catalog,
schema_=schema_,
table=table,
name=name,
display_name=name,
columns=[ColumnOutput.from_orm(col) for col in columns],
columns=[SourceColumnOutput.model_validate(col) for col in columns],
description="This source node was automatically created as a registered table.",
mode=NodeMode.PUBLISHED,
),
Expand Down Expand Up @@ -702,7 +702,7 @@ async def register_view(
table=view,
name=node_name,
display_name=node_name,
columns=[ColumnOutput.from_orm(col) for col in columns],
columns=[ColumnOutput.model_validate(col) for col in columns],
description="This source node was automatically created as a registered view.",
mode=NodeMode.PUBLISHED,
query=query,
Expand Down
Loading
Loading