@@ -301,23 +301,29 @@ def validate_and_initialize_user_module(self):
301301 )
302302 self .log_func_implementation_found_or_not (load_fn , MODEL_FN )
303303 if load_fn is not None :
304- self .load_extra_arg = self .function_extra_arg (self .load , load_fn )
304+ self .load_extra_arg = self .function_extra_arg (HuggingFaceHandlerService .load , load_fn )
305305 self .load = load_fn
306306 self .log_func_implementation_found_or_not (preprocess_fn , INPUT_FN )
307307 if preprocess_fn is not None :
308- self .preprocess_extra_arg = self .function_extra_arg (self .preprocess , preprocess_fn )
308+ self .preprocess_extra_arg = self .function_extra_arg (
309+ HuggingFaceHandlerService .preprocess , preprocess_fn
310+ )
309311 self .preprocess = preprocess_fn
310312 self .log_func_implementation_found_or_not (predict_fn , PREDICT_FN )
311313 if predict_fn is not None :
312- self .predict_extra_arg = self .function_extra_arg (self .predict , predict_fn )
314+ self .predict_extra_arg = self .function_extra_arg (HuggingFaceHandlerService .predict , predict_fn )
313315 self .predict = predict_fn
314316 self .log_func_implementation_found_or_not (postprocess_fn , OUTPUT_FN )
315317 if postprocess_fn is not None :
316- self .postprocess_extra_arg = self .function_extra_arg (self .postprocess , postprocess_fn )
318+ self .postprocess_extra_arg = self .function_extra_arg (
319+ HuggingFaceHandlerService .postprocess , postprocess_fn
320+ )
317321 self .postprocess = postprocess_fn
318322 self .log_func_implementation_found_or_not (transform_fn , TRANSFORM_FN )
319323 if transform_fn is not None :
320- self .transform_extra_arg = self .function_extra_arg (self .transform_fn , transform_fn )
324+ self .transform_extra_arg = self .function_extra_arg (
325+ HuggingFaceHandlerService .transform_fn , transform_fn
326+ )
321327 self .transform_fn = transform_fn
322328 else :
323329 logger .info (
@@ -342,8 +348,15 @@ def function_extra_arg(self, default_func, func):
342348 1. the handle function takes context
343349 2. the handle function does not take context
344350 """
345- num_default_func_input = len (signature (default_func ).parameters )
346- num_func_input = len (signature (func ).parameters )
351+ default_params = signature (default_func ).parameters
352+ func_params = signature (func ).parameters
353+
354+ if "self" in default_params :
355+ num_default_func_input = len (default_params ) - 1
356+ else :
357+ num_default_func_input = len (default_params )
358+
359+ num_func_input = len (func_params )
347360 if num_default_func_input == num_func_input :
348361 # function takes context
349362 extra_args = [self .context ]
0 commit comments