@@ -147,10 +147,10 @@ def assign_population_pcs(
147147 If you have a Pandas DataFrame with all PCs stored as an array in a single column, `expand_pd_array_col`
148148 can be used to expand this column into multiple `PC` columns.
149149
150- :param pop_pc_pd : Input Hail Table or Pandas Dataframe
150+ :param pop_pca_scores: Input Hail Table or Pandas DataFrame
151151 :param pc_cols: Columns storing the PCs to use
152152 :param known_col: Column storing the known population labels
153- :param RandomForestClassifier fit: fit from a previously trained random forest model (i.e., the output from a previous RandomForestClassifier() call)
153+ :param fit: Fit from a previously trained random forest model (i.e., the output from a previous RandomForestClassifier() call)
154154 :param seed: Random seed
155155 :param prop_train: Proportion of known data used for training
156156 :param n_estimators: Number of trees to use in the RF model
@@ -163,7 +163,12 @@ def assign_population_pcs(
163163
164164 hail_input = isinstance (pop_pca_scores , hl .Table )
165165 if hail_input :
166- pop_pc_pd = pop_pca_scores .select (known_col , pca_scores = pc_cols ).to_pandas ()
166+ if not fit :
167+ pop_pca_scores = pop_pca_scores .select (known_col , pca_scores = pc_cols )
168+ else :
169+ pop_pca_scores = pop_pca_scores .select (pca_scores = pc_cols )
170+
171+ pop_pc_pd = pop_pca_scores .to_pandas ()
167172
168173 # Explode the PC array
169174 num_out_cols = min ([len (x ) for x in pop_pc_pd ["pca_scores" ].values .tolist ()])
@@ -175,12 +180,10 @@ def assign_population_pcs(
175180 else :
176181 pop_pc_pd = pop_pca_scores
177182
178- train_data = pop_pc_pd .loc [~ pop_pc_pd [known_col ].isnull ()]
179-
180- N = len (train_data )
181-
182183 # Split training data into subsamples for fitting and evaluating
183184 if not fit :
185+ train_data = pop_pc_pd .loc [~ pop_pc_pd [known_col ].isnull ()]
186+ N = len (train_data )
184187 random .seed (seed )
185188 train_subsample_ridx = random .sample (list (range (0 , N )), int (N * prop_train ))
186189 train_fit = train_data .iloc [train_subsample_ridx ]
0 commit comments