 Created: 2020-08-06
 """

+from argparse import ArgumentParser
 from contextlib import contextmanager
 import filecmp
 from glob import glob
 from git.refs.head import Head
 import pandas as pd

+from .utils import read_params
+
 Files = List[str]
 FileDiffMap = Dict[str, Optional[str]]

+
 def diff_export_csv(
         before_csv: str,
         after_csv: str
@@ -65,7 +69,8 @@ def diff_export_csv(
     added_df is the pd.DataFrame of added rows from after_csv.
     """

-    export_csv_dtypes = {"geo_id": str, "val": float, "se": float, "sample_size": float}
+    export_csv_dtypes = {"geo_id": str, "val": float,
+                         "se": float, "sample_size": float}

     before_df = pd.read_csv(before_csv, dtype=export_csv_dtypes)
     before_df.set_index("geo_id", inplace=True)
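
As a quick reference for what `diff_export_csv` returns, here is a minimal sketch (editor's illustration, not part of the PR; file names and values are invented) that exercises all three returned frames:

import pandas as pd

# "02000" is deleted, "01000" changes value, "03000" is newly added.
pd.DataFrame({"geo_id": ["01000", "02000"], "val": [1.0, 2.0],
              "se": [0.1, 0.2], "sample_size": [10.0, 20.0]}).to_csv("before.csv", index=False)
pd.DataFrame({"geo_id": ["01000", "03000"], "val": [1.5, 3.0],
              "se": [0.1, 0.3], "sample_size": [10.0, 30.0]}).to_csv("after.csv", index=False)

deleted_df, changed_df, added_df = diff_export_csv("before.csv", "after.csv")
assert list(deleted_df.index) == ["02000"]   # rows only in before.csv
assert list(changed_df.index) == ["01000"]   # common rows whose values differ
assert list(added_df.index) == ["03000"]     # rows only in after.csv
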
@@ -89,6 +94,42 @@ def diff_export_csv(
             after_df_cmn.loc[~(same_mask.all(axis=1)), :],
             after_df.loc[added_idx, :])

+
+def run_module(archive_type: str,
+               cache_dir: str,
+               export_dir: str,
+               **kwargs):
+    """Builds and runs an ArchiveDiffer.
+
+    Parameters
+    ----------
+    archive_type: str
+        Type of ArchiveDiffer to run. Must be one of ["git", "s3"], which
+        correspond to `GitArchiveDiffer` and `S3ArchiveDiffer`, respectively.
+    cache_dir: str
+        The directory for storing most recent archived/uploaded CSVs to
+        start diffing from.
+    export_dir: str
+        The directory with most recent exported CSVs to diff to.
+    **kwargs:
+        Keyword arguments corresponding to constructor arguments for the
+        respective ArchiveDiffers.
+    """
+    if archive_type == "git":
+        arch_diff = GitArchiveDiffer(cache_dir,
+                                     export_dir,
+                                     kwargs["branch_name"],
+                                     kwargs["override_dirty"],
+                                     kwargs["commit_partial_success"],
+                                     kwargs["commit_message"])
+    elif archive_type == "s3":
+        arch_diff = S3ArchiveDiffer(cache_dir,
+                                    export_dir,
+                                    kwargs["bucket_name"],
+                                    kwargs["indicator_prefix"],
+                                    kwargs["aws_credentials"])
+    else:
+        raise ValueError(f"No archive type named '{archive_type}'")
+    arch_diff.run()
+
+
 class ArchiveDiffer:
     """
     Base class for performing diffing and archiving of exported covidcast CSVs
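
For context, this is roughly how the new `run_module` entry point would be invoked for the S3 backend; the bucket name, prefix, directories, and credentials below are placeholders, not values from the PR:

run_module("s3",
           cache_dir="./cache",
           export_dir="./receiving",
           bucket_name="example-archive-bucket",
           indicator_prefix="example_indicator",
           aws_credentials={"aws_access_key_id": "...",
                            "aws_secret_access_key": "..."})
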
@@ -140,12 +181,16 @@ def diff_exports(self) -> Tuple[Files, FileDiffMap, Files]:
         assert self._cache_updated

         # Glob to only pick out CSV files, ignore hidden files
-        previous_files = set(basename(f) for f in glob(join(self.cache_dir, "*.csv")))
-        exported_files = set(basename(f) for f in glob(join(self.export_dir, "*.csv")))
+        previous_files = set(basename(f)
+                             for f in glob(join(self.cache_dir, "*.csv")))
+        exported_files = set(basename(f)
+                             for f in glob(join(self.export_dir, "*.csv")))

-        deleted_files = sorted(join(self.cache_dir, f) for f in previous_files - exported_files)
+        deleted_files = sorted(join(self.cache_dir, f)
+                               for f in previous_files - exported_files)
         common_filenames = sorted(exported_files & previous_files)
-        new_files = sorted(join(self.export_dir, f) for f in exported_files - previous_files)
+        new_files = sorted(join(self.export_dir, f)
+                           for f in exported_files - previous_files)

         common_diffs: Dict[str, Optional[str]] = {}
         for filename in common_filenames:
@@ -158,11 +203,13 @@ def diff_exports(self) -> Tuple[Files, FileDiffMap, Files]:
             if filecmp.cmp(before_file, after_file, shallow=False):
                 continue

-            deleted_df, changed_df, added_df = diff_export_csv(before_file, after_file)
+            deleted_df, changed_df, added_df = diff_export_csv(
+                before_file, after_file)
             new_issues_df = pd.concat([changed_df, added_df], axis=0)

             if len(deleted_df) > 0:
-                print(f"Warning, diff has deleted indices in {after_file} that will be ignored")
+                print(
+                    f"Warning, diff has deleted indices in {after_file} that will be ignored")

             # Write the diffs to diff_file, if applicable
             if len(new_issues_df) > 0:
@@ -220,6 +267,29 @@ def filter_exports(self, common_diffs: FileDiffMap):
             else:
                 replace(diff_file, exported_file)

+    def run(self):
+        """Runs the differ and archives the changed and new files."""
+        self.update_cache()
+
+        # Diff exports, and make incremental versions
+        _, common_diffs, new_files = self.diff_exports()
+
+        # Archive changed and new files only
+        to_archive = [f for f, diff in common_diffs.items()
+                      if diff is not None]
+        to_archive += new_files
+        _, fails = self.archive_exports(to_archive)
+
+        # Filter existing exports to exclude those that failed to archive
+        succ_common_diffs = {f: diff for f,
+                             diff in common_diffs.items() if f not in fails}
+        self.filter_exports(succ_common_diffs)
+
+        # Report failures: someone should probably look at them
+        for exported_file in fails:
+            print(f"Failed to archive '{exported_file}'")
+
+
 class S3ArchiveDiffer(ArchiveDiffer):
     """
     AWS S3 backend for archiving
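
A sketch of driving the full cycle through the new `run()` method with the git backend; the directories, branch, and commit message are placeholders, and the constructor arguments follow the positional order used in `run_module` above:

arch_diff = GitArchiveDiffer("./cache", "./receiving",
                             "main", False, False,
                             "Automated archive of exports")
arch_diff.run()  # update cache, diff, archive changed/new files, filter exports
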
@@ -263,11 +333,14 @@ def update_cache(self):
         For making sure cache_dir is updated with all the latest files from the S3 bucket.
         """
         # List all indicator-related objects from S3
-        archive_objects = self.bucket.objects.filter(Prefix=self.indicator_prefix).all()
-        archive_objects = [obj for obj in archive_objects if obj.key.endswith(".csv")]
+        archive_objects = self.bucket.objects.filter(
+            Prefix=self.indicator_prefix).all()
+        archive_objects = [
+            obj for obj in archive_objects if obj.key.endswith(".csv")]

         # Check against what we have locally and download missing ones
-        cached_files = set(basename(f) for f in glob(join(self.cache_dir, "*.csv")))
+        cached_files = set(basename(f)
+                           for f in glob(join(self.cache_dir, "*.csv")))
         for obj in archive_objects:
             archive_file = basename(obj.key)
             cached_file = join(self.cache_dir, archive_file)
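
For readers less familiar with boto3, the cache-update loop above is roughly equivalent to this standalone sketch (bucket name, prefix, and cache path are placeholders; the download call itself sits in the unchanged lines below this hunk):

import boto3
from glob import glob
from os.path import basename, join

bucket = boto3.resource("s3").Bucket("example-archive-bucket")
cached = set(basename(f) for f in glob(join("./cache", "*.csv")))
for obj in bucket.objects.filter(Prefix="example_indicator"):
    if obj.key.endswith(".csv") and basename(obj.key) not in cached:
        # Fetch any archived CSV we do not yet have locally
        bucket.download_file(obj.key, join("./cache", basename(obj.key)))
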
@@ -297,7 +370,8 @@ def archive_exports(self, exported_files: Files) -> Tuple[Files, Files]:
         archive_fail = []

         for exported_file in exported_files:
-            cached_file = abspath(join(self.cache_dir, basename(exported_file)))
+            cached_file = abspath(
+                join(self.cache_dir, basename(exported_file)))
             archive_key = join(self.indicator_prefix, basename(exported_file))

             try:
@@ -314,6 +388,7 @@ def archive_exports(self, exported_files: Files) -> Tuple[Files, Files]:

         return archive_success, archive_fail

+
 class GitArchiveDiffer(ArchiveDiffer):
     """
     Local git repo backend for archiving
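
The body of the `try` block is collapsed in this view; judging from `cached_file` and `archive_key` above, it presumably copies the export into the cache and uploads it to S3, roughly like this boto3 sketch (all paths and names are placeholders):

import boto3
from os.path import basename, join
from shutil import copyfile

exported_file = "./receiving/20200806_state_example.csv"  # hypothetical file
copyfile(exported_file, join("./cache", basename(exported_file)))
boto3.resource("s3").Bucket("example-archive-bucket").upload_file(
    exported_file, join("example_indicator", basename(exported_file)))
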
@@ -352,7 +427,7 @@ def __init__(
         super().__init__(cache_dir, export_dir)

         assert override_dirty or not commit_partial_success, \
-            "Only can commit_partial_success=True when override_dirty=True"
+            "Can only have commit_partial_success=True when override_dirty=True"

         # Assumes a repository is set up already; will raise an exception if not found
         self.repo = Repo(cache_dir, search_parent_directories=True)
@@ -405,7 +480,8 @@ def update_cache(self):
         """
         # Make sure cache directory is clean: has everything nicely committed
         if not self.override_dirty:
-            cache_clean = not self.repo.is_dirty(untracked_files=True, path=abspath(self.cache_dir))
+            cache_clean = not self.repo.is_dirty(
+                untracked_files=True, path=abspath(self.cache_dir))
             assert cache_clean, f"There are uncommitted changes in the cache dir '{self.cache_dir}'"

         self._cache_updated = True
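
Standalone, the cleanliness check above boils down to the following GitPython call; `untracked_files=True, path=...` restricts the dirtiness test to the cache directory (the repo path here is a placeholder):

from os.path import abspath
from git import Repo

repo = Repo("./cache", search_parent_directories=True)
cache_clean = not repo.is_dirty(untracked_files=True, path=abspath("./cache"))
assert cache_clean, "There are uncommitted changes in the cache dir './cache'"
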
@@ -439,11 +515,14 @@ def archive_exports(self, exported_files: Files) -> Tuple[Files, Files]:
         with self.archiving_branch():
             # Abs paths of all modified files, to check if we would override uncommitted changes
             working_tree_dir = self.repo.working_tree_dir
-            dirty_files = [join(working_tree_dir, f) for f in self.repo.untracked_files]
-            dirty_files += [join(working_tree_dir, d.a_path) for d in self.repo.index.diff(None)]
+            dirty_files = [join(working_tree_dir, f)
+                           for f in self.repo.untracked_files]
+            dirty_files += [join(working_tree_dir, d.a_path)
+                            for d in self.repo.index.diff(None)]

             for exported_file in exported_files:
-                archive_file = abspath(join(self.cache_dir, basename(exported_file)))
+                archive_file = abspath(
+                    join(self.cache_dir, basename(exported_file)))

                 # Archive and explicitly stage the new export, depending on whether we override
                 if self.override_dirty or archive_file not in dirty_files:
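
The archive-and-stage step inside that `if` branch is collapsed here; with GitPython it presumably amounts to copying the file into the cache and explicitly adding it to the index, roughly (paths are placeholders):

from shutil import copyfile
from git import Repo

repo = Repo("./cache", search_parent_directories=True)
copyfile("./receiving/example.csv", "./cache/example.csv")
repo.index.add(["./cache/example.csv"])  # stage explicitly; committed later
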
@@ -469,11 +548,46 @@ def archive_exports(self, exported_files: Files) -> Tuple[Files, Files]:
         if len(exported_files) > 0:

             # Support partial success, as long as at least one archive succeeded
-            partial_success = self.commit_partial_success and len(archive_success) > 0
+            partial_success = self.commit_partial_success and len(
+                archive_success) > 0

             if len(archive_success) == len(exported_files) or partial_success:
                 self.repo.index.commit(message=self.commit_message)

         self._exports_archived = True

         return archive_success, archive_fail
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--archive_type", required=True, type=str,
+                        choices=["git", "s3"],
+                        help="Type of archive differ to use.")
+    parser.add_argument("--indicator_prefix", type=str, default="",
+                        help="The prefix for S3 keys related to this indicator."
+                             " Required for `archive_type` = 's3'.")
+    parser.add_argument("--branch_name", type=str, default="",
+                        help="Branch to use for `archive_type` = 'git'.")
+    parser.add_argument("--override_dirty", action="store_true",
+                        help="Whether to allow overwriting of untracked &"
+                             " uncommitted changes for `archive_type` = 'git'.")
+    parser.add_argument("--commit_partial_success", action="store_true",
+                        help="Whether to still commit for `archive_type` = "
+                             "'git' even if some files were not archived and "
+                             "staged due to `override_dirty` = False.")
+    parser.add_argument("--commit_message", type=str, default="",
+                        help="Commit message for `archive_type` = 'git'.")
+    args = parser.parse_args()
+    params = read_params()
+    run_module(args.archive_type,
+               params.cache_dir,
+               params.export_dir,
+               aws_credentials=params.aws_credentials,
+               branch_name=args.branch_name,
+               bucket_name=params.bucket_name,
+               commit_message=args.commit_message,
+               commit_partial_success=args.commit_partial_success,
+               indicator_prefix=args.indicator_prefix,
+               override_dirty=args.override_dirty)
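
Assuming the module is runnable as a script from an indicator's environment (the exact module path is not shown in this diff), the git backend could then be driven from the command line along these lines, with `cache_dir`, `export_dir`, and any S3 settings coming from `read_params()`:

python archive.py --archive_type git --branch_name main --commit_message "Automated archive"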