@@ -162,7 +162,7 @@ class NerDLGraphChecker(override val uid: String)
162162 new Param [String ](this , " graphFolder" , " Folder path that contain external graph files" )
163163
164164 /** @group getParam */
165- private def getGraphFolder : Option [String ] = get(graphFolder)
165+ protected def getGraphFolder : Option [String ] = get(graphFolder)
166166
167167 /** Extracts the graph hyperparameters from the training data (dataset).
168168 *
@@ -177,7 +177,7 @@ class NerDLGraphChecker(override val uid: String)
177177 * a tuple containing the number of labels, number of unique characters, and the embedding
178178 * dim
179179 */
180- private def getGraphParamsDs (
180+ protected def getGraphParamsDs (
181181 dataset : Dataset [_],
182182 inputCols : Array [String ],
183183 labelsCol : String ): (Int , Int , Int ) = {
@@ -219,14 +219,16 @@ class NerDLGraphChecker(override val uid: String)
219219 (nLabels, nChars, embeddingsDim)
220220 }
221221
222+ protected def searchForSuitableGraph (nLabels : Int , nChars : Int , embeddingsDim : Int ): String =
223+ NerDLApproach .searchForSuitableGraph(nLabels, embeddingsDim, nChars + 1 , getGraphFolder)
224+
222225 override def fit (dataset : Dataset [_]): NerDLGraphCheckerModel = {
223226 val (nLabels, nChars, embeddingsDim) =
224227 getGraphParamsDs(dataset, $(inputCols), $(labelColumn))
225228
226229 // Throws exception if no suitable graph found
227230 Try {
228- NerDLApproach
229- .searchForSuitableGraph(nLabels, embeddingsDim, nChars + 1 , getGraphFolder)
231+ searchForSuitableGraph(nLabels, nChars, embeddingsDim)
230232 } match {
231233 case Failure (exception : IllegalArgumentException ) =>
232234 throw new IllegalArgumentException (" NerDLGraphChecker: " + exception.getMessage)
0 commit comments