Skip to content

Commit 00347a9

Browse files
committed
feat: add schema tools
- Add a schema version listing (`latest_version`, `versions`) to the response of the `get_entity()` tool. - Add the `get_versioned_dataset` tool for retrieving a schema by version.
1 parent 5ccb13a commit 00347a9

File tree

4 files changed

+147
-1
lines changed

4 files changed

+147
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Supports both DataHub Core and DataHub Cloud.
1111
- Fetching metadata for any entity
1212
- Traversing the lineage graph, both upstream and downstream
1313
- Listing SQL queries associated with a dataset
14+
- Listing schema versions (latest version and all versions) and fetching a schema by a specific version
1415

1516
## Demo
1617

src/mcp_server_datahub/gql/entity_details.gql

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,3 +1205,57 @@ query GetEntityLineage($input: SearchAcrossLineageInput!) {
12051205
}
12061206
}
12071207
}
1208+
1209+
# Lists the schema versions of a dataset: the latest version and the full
# version list, each entry carrying its semanticVersion and versionStamp.
query getSchemaVersionList($input: GetSchemaVersionListInput!) {
1210+
getSchemaVersionList(input: $input) {
1211+
latestVersion {
1212+
semanticVersion
1213+
versionStamp
1214+
__typename
1215+
}
1216+
semanticVersionList {
1217+
semanticVersion
1218+
versionStamp
1219+
__typename
1220+
}
1221+
__typename
1222+
}
1223+
}
1224+
1225+
# Fetches a dataset's schema fields and editable per-field metadata for the
# given urn. $versionStamp is declared as a nullable String, so it may be omitted.
query getVersionedDataset($urn: String!, $versionStamp: String) {
1226+
versionedDataset(urn: $urn, versionStamp: $versionStamp) {
1227+
schema {
1228+
fields {
1229+
fieldPath
1230+
jsonPath
1231+
nullable
1232+
description
1233+
type
1234+
nativeDataType
1235+
recursive
1236+
isPartOfKey
1237+
isPartitioningKey
1238+
__typename
1239+
}
1240+
lastObserved
1241+
__typename
1242+
}
1243+
editableSchemaMetadata {
1244+
editableSchemaFieldInfo {
1245+
fieldPath
1246+
description
1247+
# NOTE(review): the globalTagsFields and glossaryTerms fragments are not
# defined in this hunk; presumably declared elsewhere in this file. Confirm.
globalTags {
1248+
...globalTagsFields
1249+
__typename
1250+
}
1251+
glossaryTerms {
1252+
...glossaryTerms
1253+
__typename
1254+
}
1255+
__typename
1256+
}
1257+
__typename
1258+
}
1259+
__typename
1260+
}
1261+
}

src/mcp_server_datahub/mcp_server.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import contextvars
3+
from functools import lru_cache
34
import json
45
import pathlib
56
from typing import Any, Dict, Iterator, List, Optional
@@ -106,7 +107,54 @@ def _clean_gql_response(response: Any) -> Any:
106107
return response
107108

108109

109-
@mcp.tool(description="Get an entity by its DataHub URN.")
110+
class SemanticVersionStruct(BaseModel):
    """One schema version of a dataset.

    Pairs the semantic version label with the version stamp that
    ``get_versioned_dataset`` sends as the ``versionStamp`` query variable.
    """

    semantic_version: str
    version_stamp: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SemanticVersionStruct":
        """Build an instance from a camelCase GraphQL payload.

        Args:
            data: Mapping with ``semanticVersion`` and ``versionStamp`` keys.

        Returns:
            The populated ``SemanticVersionStruct``.

        Raises:
            KeyError: If either expected key is absent from ``data``.
        """
        semantic_version = data["semanticVersion"]
        version_stamp = data["versionStamp"]
        return cls(semantic_version=semantic_version, version_stamp=version_stamp)
120+
121+
122+
class SchemaVersionList(BaseModel):
    """Schema version listing for a dataset: the latest version plus all versions."""

    # Entry parsed from the ``latestVersion`` part of the GraphQL response.
    latest_version: SemanticVersionStruct
    # Entries parsed from the ``semanticVersionList`` part of the response.
    versions: List[SemanticVersionStruct]
125+
126+
127+
def _get_schema_version_list(
    datahub_client: DataHubClient, dataset_urn: str
) -> SchemaVersionList | None:
    """Fetch the schema version listing for a dataset.

    Runs the ``getSchemaVersionList`` operation with ``dataset_urn`` as the
    ``datasetUrn`` input.

    Args:
        datahub_client: Client whose ``_graph`` executes the query.
        dataset_urn: URN of the dataset to look up.

    Returns:
        The parsed listing, or ``None`` when the response carries no
        ``getSchemaVersionList`` payload or that payload has no
        ``latestVersion``.
    """
    resp = _execute_graphql(
        datahub_client._graph,
        query=entity_details_fragment_gql,
        variables={"input": {"datasetUrn": dataset_urn}},
        operation_name="getSchemaVersionList",
    )
    if not (raw_schema_versions := resp.get("getSchemaVersionList")):
        return None

    # A missing or null latestVersion used to be replaced by ``{}``, which made
    # ``from_dict`` raise KeyError; treat it as "no version info" instead.
    if not (raw_latest := raw_schema_versions.get("latestVersion")):
        return None

    return SchemaVersionList(
        latest_version=SemanticVersionStruct.from_dict(raw_latest),
        # ``or []`` also covers an explicit null list in the response.
        versions=[
            SemanticVersionStruct.from_dict(raw_version)
            for raw_version in raw_schema_versions.get("semanticVersionList") or []
        ],
    )
153+
154+
155+
@mcp.tool(
156+
description="Get an entity by its DataHub URN. This also provide schema_version_list(latest version, all versions) if available."
157+
)
110158
def get_entity(urn: str) -> dict:
111159
client = get_client()
112160

@@ -125,6 +173,12 @@ def get_entity(urn: str) -> dict:
125173

126174
_inject_urls_for_urns(client._graph, result, [""])
127175

176+
if schema_version_list := _get_schema_version_list(client, urn):
177+
result["schemaVersionList"] = {
178+
"latestVersion": schema_version_list.latest_version.semantic_version,
179+
"versions": sorted([v.semantic_version for v in schema_version_list.versions]),
180+
}
181+
128182
return _clean_gql_response(result)
129183

130184

@@ -313,6 +367,34 @@ def get_lineage(urn: str, upstream: bool, max_hops: int = 1) -> dict:
313367
return lineage
314368

315369

370+
@mcp.tool(description="Get schema from a dataset by its URN and version.")
def get_versioned_dataset(dataset_urn: str, semantic_version: str) -> dict[str, Any]:
    """Return the schema of a dataset at a given semantic version.

    Looks up the dataset's schema version listing, maps ``semantic_version``
    to its version stamp, and runs the ``getVersionedDataset`` operation with
    that stamp.

    Args:
        dataset_urn: URN of the dataset.
        semantic_version: Semantic version label to fetch, e.g. ``"0.0.0"``.

    Returns:
        The ``versionedDataset`` payload of the response, or an empty dict
        when the response has none.

    Raises:
        ValueError: If the dataset has no schema version listing, or the
            requested semantic version is not in it.
    """
    # Deliberately not wrapped in ``lru_cache``: the cache key would ignore the
    # client returned by ``get_client()``, and every caller would share (and
    # could mutate) the same cached dict.
    client = get_client()

    if not (schema_version_list := _get_schema_version_list(client, dataset_urn)):
        raise ValueError(f"No schema_version_list found for dataset {dataset_urn}")

    version_stamp_mapping = {
        struct.semantic_version: struct.version_stamp
        for struct in schema_version_list.versions
    }

    if not (target_version_stamp := version_stamp_mapping.get(semantic_version)):
        raise ValueError(
            f"Version '{semantic_version}' not found for dataset '{dataset_urn}'"
        )

    resp = _execute_graphql(
        client._graph,
        query=entity_details_fragment_gql,
        variables={"urn": dataset_urn, "versionStamp": target_version_stamp},
        operation_name="getVersionedDataset",
    )
    # ``or {}`` keeps the declared dict return type when the server answers
    # with an explicit null for ``versionedDataset``.
    return resp.get("versionedDataset") or {}
396+
397+
316398
if __name__ == "__main__":
317399
import sys
318400

@@ -348,3 +430,6 @@ def _divider() -> None:
348430
_divider()
349431
print("Getting queries", urn)
350432
print(json.dumps(get_dataset_queries(urn), indent=2))
433+
_divider()
434+
# The keyword must match the tool's ``semantic_version`` parameter; the
# misspelled ``sementic_version`` raised TypeError.
print(json.dumps(get_versioned_dataset(urn, semantic_version="0.0.0"), indent=2))
435+
_divider()

tests/test_mcp_server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
get_dataset_queries,
1111
get_entity,
1212
get_lineage,
13+
get_versioned_dataset,
1314
search,
1415
with_client,
1516
)
@@ -58,6 +59,11 @@ def test_search() -> None:
5859
assert res is not None
5960

6061

62+
def test_get_versioned_dataset() -> None:
    """Smoke test: fetching schema version "0.0.0" of the test dataset yields a result."""
    versioned_schema = get_versioned_dataset(_test_urn, "0.0.0")
    assert versioned_schema is not None
65+
66+
6167
if __name__ == "__main__":
6268
import pytest
6369

0 commit comments

Comments
 (0)