diff --git a/format/format.go b/format/format.go index c5709c3..79bcd51 100644 --- a/format/format.go +++ b/format/format.go @@ -8,13 +8,18 @@ import ( ) var ( - importRe = regexp.MustCompile(`(?ms)import \(([^)]+)\)`) - otherRe = regexp.MustCompile(`^(?:var|const|func)\s`) - ErrNoImports = errors.New("no imports found") + importRe = regexp.MustCompile(`(?ms)^import \(([^)]+)\)`) + singleImportRe = regexp.MustCompile(`(?ms)^import "[^"]+"`) + otherRe = regexp.MustCompile(`^(?:var|const|func)\s`) + ErrNoImports = errors.New("no imports found") ) // Source formats a given src's imports func Source(src []byte, module string) ([]byte, error) { + if singleImportRe.Match(src) { + return src, nil + } + importStart := importRe.FindIndex(src) if importStart == nil { return nil, fmt.Errorf("could not find imports: %w", ErrNoImports) diff --git a/format/format_test.go b/format/format_test.go index 8c60602..f53a763 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -22,6 +22,9 @@ func TestSource(t *testing.T) { formatted, err := Source(before, module) assert.NoErr(err) // Should be able to format before block assert.True(bytes.Equal(formatted, after)) // Formatted should match after block + + _, err = Source(singleImport, module) + assert.NoErr(err) // Should not get an error for single import } var ( @@ -34,6 +37,11 @@ func main() { s := "import \"fmt\"" _ = s }`) + singleImport = []byte(`package main + +import "fmt" + +func main() {}`) before = []byte(`package main import (