Skip to content
This repository was archived by the owner on Mar 23, 2023. It is now read-only.

Commit dc044ec

Browse files
committed
Enable passing Python functions to Go for invocation.
1 parent 18fb95b commit dc044ec

File tree

2 files changed

+103
-5
lines changed

2 files changed

+103
-5
lines changed

runtime/function.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ func functionGet(_ *Frame, desc, instance *Object, owner *Type) (*Object, *BaseE
125125
return NewMethod(toFunctionUnsafe(desc), instance, owner).ToObject(), nil
126126
}
127127

128+
func functionNative(f *Frame, o *Object) (reflect.Value, *BaseException) {
129+
return reflect.ValueOf(o.Call), nil
130+
}
131+
128132
func functionRepr(_ *Frame, o *Object) (*Object, *BaseException) {
129133
fun := toFunctionUnsafe(o)
130134
return NewStr(fmt.Sprintf("<%s %s at %p>", fun.typ.Name(), fun.Name(), fun)).ToObject(), nil
@@ -134,6 +138,7 @@ func initFunctionType(map[string]*Object) {
134138
FunctionType.flags &= ^(typeFlagInstantiable | typeFlagBasetype)
135139
FunctionType.slots.Call = &callSlot{functionCall}
136140
FunctionType.slots.Get = &getSlot{functionGet}
141+
FunctionType.slots.Native = &nativeSlot{functionNative}
137142
FunctionType.slots.Repr = &unaryOpSlot{functionRepr}
138143
}
139144

runtime/native.go

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,22 +489,115 @@ func maybeConvertValue(f *Frame, o *Object, expectedRType reflect.Type) (reflect
489489
if raised != nil {
490490
return reflect.Value{}, raised
491491
}
492-
rtype := val.Type()
493492
for {
493+
rtype := val.Type()
494494
if rtype == expectedRType {
495495
return val, nil
496496
}
497497
if rtype.ConvertibleTo(expectedRType) {
498498
return val.Convert(expectedRType), nil
499499
}
500-
if rtype.Kind() == reflect.Ptr {
500+
switch rtype.Kind() {
501+
case reflect.Ptr:
501502
val = val.Elem()
502-
rtype = val.Type()
503503
continue
504+
505+
case reflect.Func:
506+
if fn, ok := val.Interface().(func(*Frame, Args, KWArgs) (*Object, *BaseException)); ok {
507+
val = nativeToPyFuncBridge(fn, expectedRType)
508+
continue
509+
}
504510
}
505-
break
511+
return val, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType))
512+
}
513+
}
514+
515+
var baseExceptionReflectType = reflect.TypeOf((*BaseException)(nil))
516+
517+
// pyToNativeRaised supports pushing a `raised` exception from python code to
518+
// native calling code. If the raised exception can't be returned to native
519+
// code, then the raised exception is panic-ed.
520+
func pyToNativeRaised(outs []reflect.Type, raised *BaseException) []reflect.Value {
521+
last := len(outs) - 1
522+
if len(outs) == 0 || outs[last] != baseExceptionReflectType {
523+
panic(raised)
524+
}
525+
ret := make([]reflect.Value, len(outs))
526+
for i, out := range outs[:last] {
527+
ret[i] = reflect.Zero(out)
506528
}
507-
return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType))
529+
ret[last] = reflect.ValueOf(raised)
530+
return ret
531+
}
532+
533+
var frameReflectType = reflect.TypeOf((*Frame)(nil))
534+
535+
func nativeToPyFuncBridge(fn func(*Frame, Args, KWArgs) (*Object, *BaseException), target reflect.Type) reflect.Value {
536+
firstInIsFrame := target.NumIn() > 0 && target.In(0) == frameReflectType
537+
538+
outs := make([]reflect.Type, target.NumOut())
539+
for i := range outs {
540+
outs[i] = target.Out(i)
541+
}
542+
543+
return reflect.MakeFunc(target, func(args []reflect.Value) []reflect.Value {
544+
var f *Frame
545+
if firstInIsFrame {
546+
f, args = args[0].Interface().(*Frame), args[1:]
547+
} else {
548+
f = NewRootFrame()
549+
}
550+
551+
pyArgs := f.MakeArgs(len(args))
552+
for i, arg := range args {
553+
var raised *BaseException
554+
pyArgs[i], raised = WrapNative(f, arg)
555+
if raised != nil {
556+
return pyToNativeRaised(outs, raised)
557+
}
558+
}
559+
560+
ret, raised := fn(f, pyArgs, nil)
561+
f.FreeArgs(pyArgs)
562+
if raised != nil {
563+
return pyToNativeRaised(outs, raised)
564+
}
565+
566+
switch len(outs) {
567+
case 0:
568+
if ret != nil && ret != None {
569+
return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("unexpected return of %v when None expected", ret)))
570+
}
571+
return nil
572+
573+
case 1:
574+
v, raised := maybeConvertValue(f, ret, outs[0])
575+
if raised != nil {
576+
return pyToNativeRaised(outs, raised)
577+
}
578+
return []reflect.Value{v}
579+
580+
default:
581+
converted := make([]reflect.Value, 0, len(outs))
582+
if raised := seqForEach(f, ret, func(o *Object) *BaseException {
583+
i := len(converted)
584+
if i >= len(outs) {
585+
return f.RaiseType(TypeErrorType, fmt.Sprintf("return value too long, want %d items", len(outs)))
586+
}
587+
v, raised := maybeConvertValue(f, o, outs[i])
588+
converted = append(converted, v)
589+
return raised
590+
}); raised != nil {
591+
return pyToNativeRaised(outs, raised)
592+
}
593+
594+
if len(converted) != len(outs) {
595+
return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("return value wrong size %d, want %d", len(converted), len(outs))))
596+
}
597+
598+
return converted
599+
}
600+
})
508601
}
509602

510603
func nativeFuncTypeName(rtype reflect.Type) string {

0 commit comments

Comments
 (0)