|
525 | 525 | "metadata": {}, |
526 | 526 | "outputs": [], |
527 | 527 | "source": [ |
528 | | - "# # random user ID. You can try any other ID\n", |
529 | | - "# sample_user_id = 11005" |
530 | | - ] |
531 | | - }, |
532 | | - { |
533 | | - "cell_type": "code", |
534 | | - "execution_count": null, |
535 | | - "metadata": {}, |
536 | | - "outputs": [], |
537 | | - "source": [ |
538 | | - "# featurestore_runtime = boto_session.client(service_name='sagemaker-featurestore-runtime', region_name=region)\n", |
539 | | - "\n", |
540 | | - "# feature_store_session = sagemaker.Session(\n", |
541 | | - "# boto_session=boto_session,\n", |
542 | | - "# sagemaker_client=sagemaker_client,\n", |
543 | | - "# sagemaker_featurestore_runtime_client=featurestore_runtime\n", |
544 | | - "# )" |
545 | | - ] |
546 | | - }, |
547 | | - { |
548 | | - "cell_type": "code", |
549 | | - "execution_count": null, |
550 | | - "metadata": {}, |
551 | | - "outputs": [], |
552 | | - "source": [ |
553 | | - "# # pull the sample user's 5 star preferences record from the feature store\n", |
554 | | - "# fg_response = featurestore_runtime.get_record(\n", |
555 | | - "# FeatureGroupName='user-5star-track-features-music-rec', \n", |
556 | | - "# RecordIdentifierValueAsString=str(sample_user_id)\n", |
557 | | - "# )\n", |
558 | | - "\n", |
559 | | - "# record = fg_response['Record']\n", |
560 | | - "# df_user = pd.DataFrame(record).set_index('FeatureName')\n", |
561 | | - "# df_user.to_csv(\"./data/sample_user.csv\")\n", |
562 | 528 | "df_user = pd.read_csv(\"./data/sample_user.csv\")\n", |
563 | 529 | "df_user = df_user.set_index('FeatureName')" |
564 | 530 | ] |
|
576 | 542 | "metadata": {}, |
577 | 543 | "outputs": [], |
578 | 544 | "source": [ |
579 | | - "# # pull a sample of the tracks data (multiple records) from the feature store using athena query\n", |
580 | | - "# fg_name_tracks_obj = FeatureGroup(name='track-features-music-rec', sagemaker_session=feature_store_session)\n", |
581 | | - "# tracks_query = fg_name_tracks_obj.athena_query()\n", |
582 | | - "# tracks_table = tracks_query.table_name\n", |
583 | | - "\n", |
584 | | - "# # use escaped quotes aound table name since it contains '-' symbols\n", |
585 | | - "# query_string = (\"SELECT * FROM \\\"{}\\\" LIMIT 1000\".format(tracks_table))\n", |
586 | | - "# print(\"Running \" + query_string)\n", |
587 | | - "\n", |
588 | | - "# # run Athena query. The output is loaded to a Pandas dataframe.\n", |
589 | | - "# tracks_query.run(query_string=query_string, output_location=f\"s3://{bucket}/{prefix}/query_results/\")\n", |
590 | | - "# tracks_query.wait()\n", |
591 | | - "# df_tracks = tracks_query.as_dataframe()\n", |
592 | | - "# df_tracks.to_csv(\"./data/sample_tracks.csv\")\n", |
593 | 545 | "df_tracks = pd.read_csv(\"./data/sample_tracks.csv\")" |
594 | 546 | ] |
595 | 547 | }, |
|
676 | 628 | "metadata": {}, |
677 | 629 | "outputs": [], |
678 | 630 | "source": [ |
679 | | - "df_train = pd.read_csv(train_data_uri)\n", |
| 631 | + "s3_client.download_file(bucket, f\"{prefix}/data/train/train_data.csv\", f\"train_data.csv\")\n", |
| 632 | + "df_train = pd.read_csv(\"train_data.csv\")\n", |
680 | 633 | "\n", |
681 | 634 | "label = 'rating'" |
682 | 635 | ] |
|
0 commit comments