11import  contextlib 
22import  contextvars 
3+ from  functools  import  lru_cache 
34import  json 
45import  pathlib 
56from  typing  import  Any , Dict , Iterator , List , Optional 
@@ -106,7 +107,54 @@ def _clean_gql_response(response: Any) -> Any:
106107        return  response 
107108
108109
109- @mcp .tool (description = "Get an entity by its DataHub URN." ) 
class SemanticVersionStruct(BaseModel):
    """A single schema version entry: a semantic version and its version stamp."""

    semantic_version: str
    version_stamp: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SemanticVersionStruct":
        """Build an instance from a mapping that uses camelCase keys.

        Args:
            data: Mapping holding the ``semanticVersion`` and ``versionStamp``
                entries.

        Returns:
            A new ``SemanticVersionStruct`` populated from ``data``.

        Raises:
            KeyError: If either required key is absent from ``data``.
        """
        # Read the keys in a fixed order so a missing ``semanticVersion`` is
        # reported before a missing ``versionStamp``.
        semantic_version = data["semanticVersion"]
        version_stamp = data["versionStamp"]
        return cls(semantic_version=semantic_version, version_stamp=version_stamp)
120+ 
121+ 
class SchemaVersionList(BaseModel):
    """Schema versions known for a dataset, with the latest one singled out."""

    # The entry reported as the most recent version.
    latest_version: SemanticVersionStruct
    # Every reported version entry, in the order it was received.
    versions: list[SemanticVersionStruct]
125+ 
126+ 
def _get_schema_version_list(
    datahub_client: DataHubClient, dataset_urn: str
) -> SchemaVersionList | None:
    """Fetch the schema version list of a dataset.

    Executes the ``getSchemaVersionList`` GraphQL operation for ``dataset_urn``
    and converts the response into a ``SchemaVersionList``.

    Args:
        datahub_client: Client whose underlying graph is used to run the query.
        dataset_urn: URN of the dataset whose schema versions are requested.

    Returns:
        The parsed version list, or ``None`` when the response carries no
        version list or no latest version.
    """
    resp = _execute_graphql(
        datahub_client._graph,
        query=entity_details_fragment_gql,
        variables={"input": {"datasetUrn": dataset_urn}},
        operation_name="getSchemaVersionList",
    )
    if not (raw_schema_versions := resp.get("getSchemaVersionList")):
        return None

    # ``dict.get`` only applies its default when the key is absent, so an
    # explicit null would reach ``from_dict`` as ``None``; an empty dict would
    # fail there as well because ``from_dict`` indexes required keys. Treat both
    # as "no usable version list" instead of raising.
    raw_latest = raw_schema_versions.get("latestVersion")
    if not raw_latest:
        return None

    # A null ``semanticVersionList`` is treated as an empty list.
    raw_versions = raw_schema_versions.get("semanticVersionList") or []

    return SchemaVersionList(
        latest_version=SemanticVersionStruct.from_dict(raw_latest),
        versions=[SemanticVersionStruct.from_dict(raw) for raw in raw_versions],
    )
153+ 
154+ 
155+ @mcp .tool ( 
156+     description = "Get an entity by its DataHub URN. This also provide schema_version_list(latest version, all versions) if available."  
157+ ) 
110158def  get_entity (urn : str ) ->  dict :
111159    client  =  get_client ()
112160
@@ -125,6 +173,12 @@ def get_entity(urn: str) -> dict:
125173
126174    _inject_urls_for_urns (client ._graph , result , ["" ])
127175
176+     if  schema_version_list  :=  _get_schema_version_list (client , urn ):
177+         result ["schemaVersionList" ] =  {
178+             "latestVersion" : schema_version_list .latest_version .semantic_version ,
179+             "versions" : sorted ([v .semantic_version  for  v  in  schema_version_list .versions ]),
180+         }
181+ 
128182    return  _clean_gql_response (result )
129183
130184
@@ -313,6 +367,34 @@ def get_lineage(urn: str, upstream: bool, max_hops: int = 1) -> dict:
313367    return  lineage 
314368
315369
@mcp.tool(description="Get schema from a dataset by its URN and version.")
@lru_cache
def get_versioned_dataset(dataset_urn: str, semantic_version: str) -> dict[str, Any]:
    """Return a dataset as of one of its semantic schema versions.

    Looks up the dataset's schema version list, maps ``semantic_version`` to
    its version stamp, and runs the ``getVersionedDataset`` GraphQL operation
    with that stamp.

    Args:
        dataset_urn: URN of the dataset to fetch.
        semantic_version: Semantic version string to resolve, as listed in the
            dataset's schema version list.

    Returns:
        The ``versionedDataset`` payload of the response, or an empty dict when
        the response has none.

    Raises:
        ValueError: If the dataset has no schema version list, or if
            ``semantic_version`` is not one of its versions.
    """
    # NOTE(review): lru_cache keys only on (dataset_urn, semantic_version). The
    # client returned by get_client() is not part of the key, and the same dict
    # object is handed to every caller that hits the cache -- confirm that is
    # intended.
    client = get_client()

    if not (schema_version_list := _get_schema_version_list(client, dataset_urn)):
        raise ValueError(f"No schema_version_list found for dataset {dataset_urn}")

    version_stamp_mapping = {
        struct.semantic_version: struct.version_stamp
        for struct in schema_version_list.versions
    }

    if not (target_version_stamp := version_stamp_mapping.get(semantic_version)):
        # List the known versions so the caller can retry with a valid one.
        available = ", ".join(sorted(version_stamp_mapping)) or "<none>"
        raise ValueError(
            f"Version '{semantic_version}' not found for dataset {dataset_urn}. "
            f"Available versions: {available}"
        )

    variables = {"urn": dataset_urn, "versionStamp": target_version_stamp}
    resp = _execute_graphql(
        client._graph,
        query=entity_details_fragment_gql,
        variables=variables,
        operation_name="getVersionedDataset",
    )
    # ``or {}`` also covers an explicit null, keeping the declared dict return.
    return resp.get("versionedDataset") or {}
396+ 
397+ 
316398if  __name__  ==  "__main__" :
317399    import  sys 
318400
@@ -348,3 +430,6 @@ def _divider() -> None:
348430    _divider ()
349431    print ("Getting queries" , urn )
350432    print (json .dumps (get_dataset_queries (urn ), indent = 2 ))
433+     _divider ()
434+     print (json .dumps (get_versioned_dataset (urn , sementic_version = "0.0.0" ), indent = 2 ))
435+     _divider ()
0 commit comments