Skip to content

Commit

Permalink
Merge pull request #9 from percolate/embedded_interfaces
Browse files Browse the repository at this point in the history
Support Embedded Interfaces
  • Loading branch information
kevinbirch authored Dec 1, 2017
2 parents 93c6aee + 12a6142 commit 56d4c3d
Show file tree
Hide file tree
Showing 54 changed files with 1,710 additions and 395 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version: 2
jobs:
build:
docker:
- image: circleci/golang:1.9.0
- image: circleci/golang:1.9.2-stretch

working_directory: /go/src/github.com/percolate/charlatan

Expand All @@ -14,7 +14,7 @@ jobs:
shell: /bin/bash
name: go fmt
command: |
! gofmt -l *.go testdata/*.go 2>&1 | read
! gofmt -l $(find . -path ./vendor -prune -o -type f -name '*.go' -print) 2>&1 | read
- run: go vet
- run: make test
- run: sbin/codecov -s build/coverage/
Expand Down
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ DIAGRAMS := $(DIAGRAM_DIR)/architecture.png
BUILD_DIR := build
COVERAGE_DIR := $(BUILD_DIR)/coverage
TESTDATA_SOURCES := $(shell find testdata -name "*_def.go")
IGNORED_TESTDATA := testdata/__def.go testdata/emptier_def.go
IGNORED_TESTDATA := testdata/_/__def.go testdata/emptier/emptier_def.go
GENERATED_TESTDATA := $(subst _def,,$(filter-out $(IGNORED_TESTDATA),$(TESTDATA_SOURCES)))

all: test
Expand Down Expand Up @@ -41,7 +41,8 @@ charlatan:

# Get the capitalized interface name from the filename and pass it to charlatan
%.go: %_def.go
iface=$(*F); ./charlatan -file=$< -output=$@ $${iface^}
rm -f $@
iface=$(*F); ./charlatan -dir=testdata/$(*F) -output=$@ $${iface^}

test: $(COVERAGE_DIR)
go test -v -coverprofile=$(TOP_DIR)/$(COVERAGE_DIR)/$(@F)_coverage.out -covermode=atomic ./...
Expand Down
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.2
0.9.3
34 changes: 7 additions & 27 deletions charlatan.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@ var (
outputPath = flag.String("output", "", "output file path [default: ./charlatan.go]")
outputPackage = flag.String("package", "", "output package name [default: \"<current package>\"]")
dirName = flag.String("dir", "", "input package directory [default: current package directory]")
fileNames stringSliceValue
)

func init() {
log.SetFlags(0)
log.SetPrefix("charlatan: ")
flag.Usage = usage
flag.Var(&fileNames, "file", "name of input file, may be repeated, ignored if -dir is present")
}

func usage() {
Expand All @@ -71,32 +69,14 @@ func main() {
os.Exit(1)
}

var (
g *Generator
err error
)

packageDirectory := "."
if *dirName != "" {
g, err = LoadPackageDir(*dirName)
if err != nil {
log.Fatal(err)
}
} else if len(fileNames) != 0 {
for _, name := range fileNames[1:] {
if *dirName != filepath.Dir(name) {
log.Fatal("all input source files must be in the same package directory")
}
}
g, err = LoadPackageFiles(fileNames)
if err != nil {
log.Fatal(err)
}
} else {
// process the package in current directory.
g, err = LoadPackageDir(".")
if err != nil {
log.Fatal(err)
}
packageDirectory = *dirName
}

g, err := LoadPackageDir(packageDirectory)
if err != nil {
log.Fatal(err)
}

g.PackageOverride = *outputPackage
Expand Down
34 changes: 17 additions & 17 deletions endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

// This file contains a test that compiles and runs each program in testdata
// after generating the mocks for its interface. The rule is that for
// testdata/x.go we run `charlatan -file=testdata/x_def.go X` and then compile
// testdata/x.go we run `charlatan -dir=testdata X` and then compile
// and run the testdata/x.go program. The resulting binary panics if the mock
// structs are broken, including for error cases.

Expand All @@ -26,34 +26,34 @@ type endToEndTest struct {

func (e *endToEndTest) compileAndRun(t *testing.T) {
t.Parallel()
dir, err := ioutil.TempDir("", "charlatan")
tempdir, err := ioutil.TempDir("", "charlatan")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)

source := filepath.Join(dir, path.Base(e.file))
err = copy(source, e.file)
if err != nil {
t.Fatalf("copying end-to-end test file to temporary directory: %s", err)
}
defer os.RemoveAll(tempdir)

base := strings.TrimSuffix(path.Base(e.file), "_ete.go")
interfaceName := strings.Title(base)

sourceDef := filepath.Join(dir, base+"_def.go")
err = copy(sourceDef, filepath.Join("testdata", base+"_def.go"))
sourceDef := filepath.Join(tempdir, base+"_def.go")
err = copy(sourceDef, filepath.Join("testdata/"+base, base+"_def.go"))
if err != nil {
t.Fatalf("copying end-to-end interface definition file to temporary directory: %s", err)
t.Fatalf("copying interface definition file to temporary directory: %s", err)
}

charlatanSource := filepath.Join(dir, interfaceName+"_charlatan.go")
charlatanSource := filepath.Join(tempdir, interfaceName+"_charlatan.go")
// Run charlatan in temporary directory.
err = run(e.exe, "-file", sourceDef, "-output", charlatanSource, "-package", "main", interfaceName)
err = run(e.exe, "-dir", tempdir, "-output", charlatanSource, "-package", "main", interfaceName)
if err != nil {
t.Fatal(err)
}

source := filepath.Join(tempdir, path.Base(e.file))
err = copy(source, e.file)
if err != nil {
t.Fatalf("copying end-to-end test file to temporary directory: %s", err)
}

// Run the binary in the temporary directory.
err = run("go", "run", charlatanSource, sourceDef, source)
if err != nil {
Expand All @@ -62,19 +62,19 @@ func (e *endToEndTest) compileAndRun(t *testing.T) {
}

func TestEndToEnd(t *testing.T) {
dir, err := ioutil.TempDir("", "charlatan")
tempdir, err := ioutil.TempDir("", "charlatan")
if err != nil {
t.Fatal(err)
}

// Create charlatan in temporary directory.
charlatan := filepath.Join(dir, "charlatan.exe")
charlatan := filepath.Join(tempdir, "charlatan.exe")
err = run("go", "build", "-o", charlatan)
if err != nil {
t.Fatalf("building charlatan: %s", err)
}

names, err := filepath.Glob("testdata/*_ete.go")
names, err := filepath.Glob("testdata/ete/*_ete.go")
if err != nil {
t.Fatalf("finding end-to-end test files: %s", err)
}
Expand Down
145 changes: 111 additions & 34 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ func LoadPackageDir(directory string) (*Generator, error) {
return parsePackage(directory, names)
}

// LoadPackageFiles parses a package using only the given files.
func LoadPackageFiles(names []string) (*Generator, error) {
return parsePackage(".", names)
}

func parsePackage(directory string, filenames []string) (*Generator, error) {
generator := &Generator{
imports: new(ImportSet),
Expand All @@ -55,6 +50,7 @@ func parsePackage(directory string, filenames []string) (*Generator, error) {
files := make([]*ast.File, 0, len(filenames))
fileset := token.NewFileSet()
importer := defaultImporter()

for _, filename := range filenames {
if !strings.HasSuffix(filename, ".go") {
continue
Expand All @@ -63,29 +59,31 @@ func parsePackage(directory string, filenames []string) (*Generator, error) {
if err != nil {
return nil, fmt.Errorf("syntax error: %s", err)
}
if err := generator.extractImports(file, importer); err != nil {
if err := generator.processImports(file, importer); err != nil {
return nil, err
}
if err := generator.extractInterfaces(file); err != nil {
if err := generator.processInterfaces(file); err != nil {
return nil, err
}
files = append(files, file)
}
if len(files) == 0 {
return nil, fmt.Errorf("error: no Go files found in %s", directory)
}
generator.packageName = files[0].Name.Name

// Type check the package.
// N.B. - type check the package
config := types.Config{Importer: importer, Error: func(err error) { fmt.Fprintln(os.Stderr, err) }}
if _, err := config.Check(directory, fileset, files, nil); err != nil {
pkg, err := config.Check(directory, fileset, files, nil)
if err != nil {
return nil, fmt.Errorf("type check failed")
}

generator.packageName = pkg.Name()

return generator, nil
}

func (g *Generator) extractImports(file *ast.File, importer types.Importer) error {
func (g *Generator) processImports(file *ast.File, importer types.Importer) error {
for _, spec := range file.Imports {
path, err := strconv.Unquote(spec.Path.Value)
if err != nil {
Expand All @@ -95,32 +93,76 @@ func (g *Generator) extractImports(file *ast.File, importer types.Importer) erro
if err != nil {
return err
}
decl := &Import{
Name: pkg.Name(),
Path: spec.Path.Value,

g.processImport(spec, pkg)
if err := g.processImportInterfaces(pkg); err != nil {
return err
}

if spec.Name == nil {
g.imports.Add(decl)
}

return nil
}

func (g *Generator) processImport(spec *ast.ImportSpec, pkg *types.Package) {
decl := &Import{
Name: pkg.Name(),
Path: spec.Path.Value,
}

if spec.Name == nil {
g.imports.Add(decl)
return
}

switch spec.Name.Name {
case "_":
break
case ".":
decl.Required = true
decl.Alias = "."
default:
decl.Alias = spec.Name.Name
}

g.imports.Add(decl)
}

func (g *Generator) processImportInterfaces(pkg *types.Package) error {
for _, name := range pkg.Scope().Names() {
obj := pkg.Scope().Lookup(name)

qname := fmt.Sprintf("%s.%s", pkg.Name(), obj.Name())
if _, exists := g.interfaces[qname]; exists {
continue
}

switch spec.Name.Name {
case "_":
if _, isType := obj.(*types.TypeName); !isType || !obj.Exported() || !types.IsInterface(obj.Type()) {
continue
case ".":
decl.Required = true
decl.Alias = "."
default:
decl.Alias = spec.Name.Name
}
g.imports.Add(decl)

ifType := obj.Type().Underlying().(*types.Interface)
decl := &Interface{
Name: obj.Name(),
}

for i := 0; i < ifType.NumMethods(); i++ {
m := ifType.Method(i)
if !m.Exported() {
continue
}
if err := decl.addMethodFromType(m, g.imports); err != nil {
return err
}
}

g.interfaces[qname] = decl
}

return nil
}

func (g *Generator) extractInterfaces(file *ast.File) (err error) {
func (g *Generator) processInterfaces(file *ast.File) error {
for _, node := range file.Decls {
gen, ok := node.(*ast.GenDecl)
if !ok || gen.Tok != token.TYPE {
Expand All @@ -132,22 +174,39 @@ func (g *Generator) extractInterfaces(file *ast.File) (err error) {
continue
}

decl := &Interface{
Name: spec.Name.Name,
decl, err := g.processInterface(spec.Name.Name, ifType)
if err != nil {
return err
}
g.interfaces[spec.Name.Name] = decl
}

return nil
}

func (g *Generator) processInterface(name string, ifType *ast.InterfaceType) (*Interface, error) {
decl := &Interface{
Name: name,
}

for _, method := range ifType.Methods.List {
if _, ok := method.Type.(*ast.FuncType); ok {
err = decl.addMethod(method, g.imports)
if err != nil {
return
}
for _, field := range ifType.Methods.List {
switch f := field.Type.(type) {
case *ast.FuncType:
if err := decl.addMethodFromField(field, g.imports); err != nil {
return nil, err
}
case *ast.Ident:
// N.B. - embedded interface from current package
decl.embeds = append(decl.embeds, f.Name)
case *ast.SelectorExpr:
// N.B. - embedded interface from imported package
decl.embeds = append(decl.embeds, fmt.Sprintf("%s.%s", f.X.(*ast.Ident).String(), f.Sel.String()))
default:
return nil, fmt.Errorf("internal error: unsupported interface field: %#v\n", field.Type)
}
}

return
return decl, nil
}

// Generate produces the charlatan source file data for the named interfaces.
Expand All @@ -167,6 +226,24 @@ func (g *Generator) Generate(interfaceNames []string) ([]byte, error) {
continue
}
decls = append(decls, decl)
if len(decl.embeds) == 0 {
continue
}

embeddedMethods := []*Method{}
for _, embedName := range decl.embeds {
embed, ok := g.interfaces[embedName]
if !ok {
return nil, fmt.Errorf("error: interface %q embedded in %s not found", embedName, name)
}

for _, m := range embed.Methods {
c := *m
c.Interface = decl.Name
embeddedMethods = append(embeddedMethods, &c)
}
}
decl.Methods = append(embeddedMethods, decl.Methods...)
}

if len(decls) == 0 {
Expand Down
Loading

0 comments on commit 56d4c3d

Please sign in to comment.