3434from pyspark .sql import DataFrame
3535from pyspark .ml .param import Param , Params
3636from pyspark .ml .param .shared import HasMaxIter , HasInputCol
37- from pyspark .ml . pipeline import Estimator , Model , Pipeline , Transformer
37+ from pyspark .ml import Estimator , Model , Pipeline , Transformer
3838
3939
4040class MockDataset (DataFrame ):
@@ -49,6 +49,9 @@ def __init__(self):
4949 super (HasFake , self ).__init__ ()
5050 self .fake = Param (self , "fake" , "fake param" )
5151
def getFake(self):
    """Return the current value of the ``fake`` param.

    The lookup is delegated to ``self.getOrDefault`` with ``self.fake`` as
    the key.
    NOTE(review): presumably this yields the explicitly set value when there
    is one and the param's default otherwise -- confirm against the Params
    base class.
    """
    return self.getOrDefault(self.fake)
54+
5255
5356class MockTransformer (Transformer , HasFake ):
5457
@@ -71,6 +74,7 @@ def __init__(self):
def _fit(self, dataset):
    """Record the dataset position and return a fresh ``MockModel``.

    :param dataset: object exposing an ``index`` attribute; its value is
        stored on ``self.dataset_index`` so callers can inspect which
        dataset this estimator was fitted on.
    :return: a new ``MockModel`` that has been passed through
        ``self._copyValues``.
    """
    self.dataset_index = dataset.index
    model = MockModel()
    # NOTE(review): presumably carries this estimator's param values over to
    # the model so that param getters on the model see them -- confirm
    # against the Params base class.
    self._copyValues(model)
    return model
7579
7680
@@ -86,12 +90,13 @@ def test_pipeline(self):
8690 transformer1 = MockTransformer ()
8791 estimator2 = MockEstimator ()
8892 transformer3 = MockTransformer ()
89- pipeline = Pipeline () \
90- .setStages ([estimator0 , transformer1 , estimator2 , transformer3 ])
93+ pipeline = Pipeline (stages = [estimator0 , transformer1 , estimator2 , transformer3 ])
9194 pipeline_model = pipeline .fit (dataset , {estimator0 .fake : 0 , transformer1 .fake : 1 })
9295 model0 , transformer1 , model2 , transformer3 = pipeline_model .stages
9396 self .assertEqual (0 , model0 .dataset_index )
97+ self .assertEqual (0 , model0 .getFake ())
9498 self .assertEqual (1 , transformer1 .dataset_index )
99+ self .assertEqual (1 , transformer1 .getFake ())
95100 self .assertEqual (2 , dataset .index )
96101 self .assertIsNone (model2 .dataset_index , "The last model shouldn't be called in fit." )
97102 self .assertIsNone (transformer3 .dataset_index ,
0 commit comments