diff --git a/config.go b/config.go index a2ef7a4..95d6775 100644 --- a/config.go +++ b/config.go @@ -285,7 +285,7 @@ func fmtErrors(msg string, errs []error) error { } // TODO: test cases -func applyAutoConvertFunctions(cfgs []structConfig) []structConfig { +func applyAutoConvertFunctions(localPackage string, cfgs []structConfig) []structConfig { // Index the structs by name so any struct can refer to conversion // functions for any other struct. byName := make(map[string]structConfig, len(cfgs)) @@ -294,8 +294,9 @@ func applyAutoConvertFunctions(cfgs []structConfig) []structConfig { } for structIdx, s := range cfgs { - imports := newImports() + imports := newImports(localPackage) imports.Add("", s.Target.Package) + imports.NeedInFile(s.Target.Package) for fieldIdx, f := range s.Fields { if _, ignored := s.IgnoreFields[f.SourceName]; ignored { diff --git a/generate.go b/generate.go index 3726752..d25571e 100644 --- a/generate.go +++ b/generate.go @@ -20,7 +20,7 @@ func generateFiles(cfg config, targets map[string]targetPkg) error { for _, group := range byOutput { var decls []ast.Decl - imports := newImports() + imports := newImports(cfg.SourcePkg.pkg.PkgPath) for _, sourceStruct := range group { t := targets[sourceStruct.Target.Package].Structs[sourceStruct.Target.Struct] @@ -92,6 +92,7 @@ func generateConversion(cfg structConfig, t targetStruct, imports *imports) (gen var g generated imports.Add("", cfg.Target.Package) + imports.NeedInFile(cfg.Target.Package) targetType := &ast.SelectorExpr{ X: &ast.Ident{Name: path.Base(imports.AliasFor(cfg.Target.Package))}, @@ -129,6 +130,13 @@ func generateConversion(cfg structConfig, t targetStruct, imports *imports) (gen Sel: &ast.Ident{Name: name}, } + if _, pkg := importFromType(sourceField.SourceType, imports); pkg != "" { + imports.Add("", pkg) // source packages + } + if _, pkg := importFromType(field.Type(), imports); pkg != "" { + imports.Add("", pkg) // target packages + } + if sourceField.FuncTo != "" || sourceField.FuncFrom != "" { to.Body.List = append(to.Body.List, newAssignStmtUserFunc( targetExpr, @@ -165,7 +173,7 @@ func generateConversion(cfg structConfig, t targetStruct, imports *imports) (gen } // the assignmentKind is := so target==LHS source==RHS - rawKind, ok := computeAssignment(field.Type(), sourceField.SourceType) + rawKind, ok := computeAssignment(field.Type(), sourceField.SourceType, imports) if !ok { assignErrFn(nil) continue @@ -400,24 +408,34 @@ func writeFile(output string, contents []byte) error { } type imports struct { - byPkgPath map[string]string // package => alias(or default) - byAlias map[string]string // alias(or default) => package - hasAlias map[string]struct{} // package is using a non-default name + byPkgPath map[string]string // package => alias(or default) + byAlias map[string]string // alias(or default) => package + hasAlias map[string]struct{} // package is using a non-default name + neededInFile map[string]struct{} // package import is required in file + localPackage string } -func newImports() *imports { +func newImports(localPackage string) *imports { return &imports{ - byPkgPath: make(map[string]string), - byAlias: make(map[string]string), - hasAlias: make(map[string]struct{}), + byPkgPath: make(map[string]string), + byAlias: make(map[string]string), + hasAlias: make(map[string]struct{}), + neededInFile: make(map[string]struct{}), + localPackage: localPackage, + } +} + +func (i *imports) NeedInFile(pkgPath string) { + if _, exists := i.byPkgPath[pkgPath]; !exists { + panic("Only call NeedInFile after Add") } + + i.neededInFile[pkgPath] = struct{}{} } // Add an import with an optional alias. If no alias is specified, the default // alias will be path.Base(). The alias for a package should always be looked up // from AliasFor. -// -// TODO: remove alias arg? func (i *imports) Add(alias string, pkgPath string) { if _, exists := i.byPkgPath[pkgPath]; exists { return @@ -446,6 +464,9 @@ func (i *imports) Add(alias string, pkgPath string) { } func (i *imports) AliasFor(pkgPath string) string { + if pkgPath == i.localPackage { + return "" + } return i.byPkgPath[pkgPath] } @@ -454,7 +475,11 @@ func (i *imports) Decl() *ast.GenDecl { paths := make([]string, 0, len(i.byPkgPath)) for pkgPath := range i.byPkgPath { - paths = append(paths, pkgPath) + if pkgPath != i.localPackage { + if _, ok := i.neededInFile[pkgPath]; ok { + paths = append(paths, pkgPath) + } + } } sort.Strings(paths) diff --git a/generate_test.go b/generate_test.go index 8d33654..d166145 100644 --- a/generate_test.go +++ b/generate_test.go @@ -76,7 +76,7 @@ func TestGenerateConversion(t *testing.T) { newField("ID", types.Typ[types.String]), }, } - imports := newImports() + imports := newImports("TODO") gen, err := generateConversion(c, target, imports) assert.NilError(t, err) @@ -115,14 +115,14 @@ func TestGenerateConversion_WithMissingSourceField(t *testing.T) { newField("Name", types.Typ[types.String]), }, } - imports := newImports() + imports := newImports("TODO") _, err := generateConversion(c, target, imports) expected := "struct Node is missing field Name. Add the missing field or exclude it" assert.ErrorContains(t, err, expected) } func TestImports(t *testing.T) { - imp := newImports() + imp := newImports("TODO") t.Run("add duplicate import", func(t *testing.T) { imp.Add("", "example.com/foo") @@ -150,6 +150,12 @@ func TestImports(t *testing.T) { t.Skip("Decls value depends on previous subtests") } t.Run("Decls", func(t *testing.T) { + imp.NeedInFile("example.com/foo") + imp.NeedInFile("example.com/some/foo") + imp.NeedInFile("example.com/stars") + + imp.Add("", "example.com/totally-ignored") + file := &ast.File{Name: &ast.Ident{Name: "src"}} file.Decls = append(file.Decls, imp.Decl()) out, err := astToBytes(&token.FileSet{}, file) diff --git a/internal/e2e/core/cluster_node.go b/internal/e2e/core/cluster_node.go index bde0f12..3dc0ab4 100644 --- a/internal/e2e/core/cluster_node.go +++ b/internal/e2e/core/cluster_node.go @@ -9,6 +9,9 @@ type ClusterNode struct { Label Label // Labels []Label // WorkPointer []*Workload + InnerLabel inner.Label + InnerLabel2 inner.Label + InnerLabel3 inner.Label O *Other I inner.Inner diff --git a/internal/e2e/core/inner/inner.go b/internal/e2e/core/inner/inner.go index d156dfe..e5c45a5 100644 --- a/internal/e2e/core/inner/inner.go +++ b/internal/e2e/core/inner/inner.go @@ -3,3 +3,5 @@ package inner type Inner struct { M string } + +type Label string diff --git a/internal/e2e/sourcepkg/node.go b/internal/e2e/sourcepkg/node.go index fa48614..fd7d98e 100644 --- a/internal/e2e/sourcepkg/node.go +++ b/internal/e2e/sourcepkg/node.go @@ -3,6 +3,7 @@ package sourcepkg import ( "github.com/hashicorp/mog/internal/e2e/core" "github.com/hashicorp/mog/internal/e2e/core/inner" + "github.com/hashicorp/mog/internal/e2e/sourcepkg/outer" ) // Node source structure for e2e testing mog. @@ -20,6 +21,9 @@ type Node struct { Meta map[string]interface{} Work []Workload // WorkPointer []*Workload + InnerLabel string + InnerLabel2 outer.Label + InnerLabel3 LocalLabel O *core.Other I inner.Inner @@ -65,6 +69,8 @@ type Node struct { type StringSlice []string type WorkloadSlice []Workload +type LocalLabel string + // mog annotation: // // name=Core diff --git a/internal/e2e/sourcepkg/outer/outer.go b/internal/e2e/sourcepkg/outer/outer.go new file mode 100644 index 0000000..43a25c8 --- /dev/null +++ b/internal/e2e/sourcepkg/outer/outer.go @@ -0,0 +1,3 @@ +package outer + +type Label string diff --git a/main.go b/main.go index 3ae9f43..1782319 100644 --- a/main.go +++ b/main.go @@ -87,7 +87,7 @@ func runMog(opts options) error { return fmt.Errorf("failed to load targets: %w", err) } - cfg.Structs = applyAutoConvertFunctions(cfg.Structs) + cfg.Structs = applyAutoConvertFunctions(cfg.SourcePkg.pkg.PkgPath, cfg.Structs) log.Printf("Generating code for %d structs", len(cfg.Structs)) diff --git a/mapping.go b/mapping.go index 1ef80a3..c32dd3f 100644 --- a/mapping.go +++ b/mapping.go @@ -176,7 +176,7 @@ func convertibleButNotIdentical(typ, typeDecode types.Type) bool { // // If this is not possible, or not currently supported (nil, false) is // returned. -func computeAssignment(leftType, rightType types.Type) (assignmentKind, bool) { +func computeAssignment(leftType, rightType types.Type, imports *imports) (assignmentKind, bool) { // First check if the types are naturally directly assignable. Only allow // type pairs that are symmetrically assignable for simplicity. if types.AssignableTo(rightType, leftType) { @@ -201,6 +201,15 @@ func computeAssignment(leftType, rightType types.Type) (assignmentKind, bool) { if convertibleButNotIdentical(rightType, rightTypeDecode) || convertibleButNotIdentical(leftType, leftTypeDecode) { + if _, pkg := importFromType(leftType, imports); pkg != "" { + imports.Add("", pkg) + imports.NeedInFile(pkg) + } + if _, pkg := importFromType(rightType, imports); pkg != "" { + imports.Add("", pkg) + imports.NeedInFile(pkg) + } + return &singleAssignmentKind{ Left: leftType, Right: rightType, @@ -238,7 +247,7 @@ func computeAssignment(leftType, rightType types.Type) (assignmentKind, bool) { } // the elements have to be assignable - rawOp, ok := computeAssignment(left.Elem(), right.Elem()) + rawOp, ok := computeAssignment(left.Elem(), right.Elem(), imports) if !ok { return nil, false } @@ -262,7 +271,7 @@ func computeAssignment(leftType, rightType types.Type) (assignmentKind, bool) { return nil, false } - rawKeyOp, ok := computeAssignment(left.Key(), right.Key()) + rawKeyOp, ok := computeAssignment(left.Key(), right.Key(), imports) if !ok { return nil, false } @@ -277,7 +286,7 @@ func computeAssignment(leftType, rightType types.Type) (assignmentKind, bool) { } // the map values have to be assignable - rawOp, ok := computeAssignment(left.Elem(), right.Elem()) + rawOp, ok := computeAssignment(left.Elem(), right.Elem(), imports) if !ok { return nil, false } diff --git a/testdata/TestE2E-expected-node_gen.go b/testdata/TestE2E-expected-node_gen.go index effa7b4..acb110b 100644 --- a/testdata/TestE2E-expected-node_gen.go +++ b/testdata/TestE2E-expected-node_gen.go @@ -2,7 +2,11 @@ package sourcepkg -import "github.com/hashicorp/mog/internal/e2e/core" +import ( + "github.com/hashicorp/mog/internal/e2e/core" + "github.com/hashicorp/mog/internal/e2e/core/inner" + "github.com/hashicorp/mog/internal/e2e/sourcepkg/outer" +) func NodeToCore(s *Node, t *core.ClusterNode) { if s == nil { @@ -10,6 +14,9 @@ func NodeToCore(s *Node, t *core.ClusterNode) { } t.ID = s.ID t.Label = core.Label(s.Label) + t.InnerLabel = inner.Label(s.InnerLabel) + t.InnerLabel2 = inner.Label(s.InnerLabel2) + t.InnerLabel3 = inner.Label(s.InnerLabel3) t.O = s.O t.I = s.I WorkloadToCore(&s.F1, &t.F1) @@ -173,6 +180,9 @@ func NodeFromCore(t *core.ClusterNode, s *Node) { } s.ID = t.ID s.Label = string(t.Label) + s.InnerLabel = string(t.InnerLabel) + s.InnerLabel2 = outer.Label(t.InnerLabel2) + s.InnerLabel3 = LocalLabel(t.InnerLabel3) s.O = t.O s.I = t.I WorkloadFromCore(&t.F1, &s.F1) diff --git a/types.go b/types.go index 90e5fae..ad3949c 100644 --- a/types.go +++ b/types.go @@ -8,6 +8,31 @@ import ( "path" ) +func importFromType(t types.Type, imports *imports) (alias, pkg string) { + if os.Getenv("DEBUG_MOG") == "1" { + defer func() { + fmt.Printf("IMPORT-FROM-TYPE: [%T :: %+v] => [%q, %q]\n", + t, t, alias, pkg, + ) + }() + } + + switch x := t.(type) { + case *types.Basic: + return "", "" + case *types.Named: + target := x.Obj() + return target.Pkg().Name(), target.Pkg().Path() + case *types.Pointer: + return importFromType(x.Elem(), imports) + // case *types.Slice: + // case *types.Map: + // case *types.Interface: + default: + return "", "" + } +} + // typeToExpr converts a go/types representation of a type into a go/ast // representation of a type. //