diff --git a/tools/goctl/api/cmd.go b/tools/goctl/api/cmd.go index 0863805285eb..c16fd1729097 100644 --- a/tools/goctl/api/cmd.go +++ b/tools/goctl/api/cmd.go @@ -72,6 +72,7 @@ func init() { goCmdFlags.StringVar(&gogen.VarStringHome, "home") goCmdFlags.StringVar(&gogen.VarStringRemote, "remote") goCmdFlags.StringVar(&gogen.VarStringBranch, "branch") + goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test") goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat) javaCmdFlags.StringVar(&javagen.VarStringDir, "dir") diff --git a/tools/goctl/api/gogen/gen.go b/tools/goctl/api/gogen/gen.go index 6f568fe5b783..676cfc37601a 100644 --- a/tools/goctl/api/gogen/gen.go +++ b/tools/goctl/api/gogen/gen.go @@ -38,7 +38,8 @@ var ( // VarStringBranch describes the branch. VarStringBranch string // VarStringStyle describes the style of output files. - VarStringStyle string + VarStringStyle string + VarBoolWithTest bool ) // GoCommand gen go project files from command line @@ -49,6 +50,7 @@ func GoCommand(_ *cobra.Command, _ []string) error { home := VarStringHome remote := VarStringRemote branch := VarStringBranch + withTest := VarBoolWithTest if len(remote) > 0 { repo, _ := util.CloneIntoGitHome(remote, branch) if len(repo) > 0 { @@ -66,11 +68,11 @@ func GoCommand(_ *cobra.Command, _ []string) error { return errors.New("missing -dir") } - return DoGenProject(apiFile, dir, namingStyle) + return DoGenProject(apiFile, dir, namingStyle, withTest) } // DoGenProject gen go project files with api file -func DoGenProject(apiFile, dir, style string) error { +func DoGenProject(apiFile, dir, style string, withTest bool) error { api, err := parser.Parse(apiFile) if err != nil { return err @@ -100,6 +102,10 @@ func DoGenProject(apiFile, dir, style string) error { logx.Must(genHandlers(dir, rootPkg, cfg, api)) logx.Must(genLogic(dir, rootPkg, cfg, api)) logx.Must(genMiddleware(dir, cfg, api)) + if withTest { + logx.Must(genHandlersTest(dir, rootPkg, cfg, api)) + logx.Must(genLogicTest(dir, rootPkg, cfg, api)) + } if err := backupAndSweep(apiFile); err != nil { return err diff --git a/tools/goctl/api/gogen/gen_test.go b/tools/goctl/api/gogen/gen_test.go index 56d5f9250140..5fe8207c0954 100644 --- a/tools/goctl/api/gogen/gen_test.go +++ b/tools/goctl/api/gogen/gen_test.go @@ -348,7 +348,7 @@ func validateWithCamel(t *testing.T, api, camel string) { assert.Nil(t, err) err = initMod(dir) assert.Nil(t, err) - err = DoGenProject(api, dir, camel) + err = DoGenProject(api, dir, camel, true) assert.Nil(t, err) filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { if strings.HasSuffix(path, ".go") { diff --git a/tools/goctl/api/gogen/genhandlerstest.go b/tools/goctl/api/gogen/genhandlerstest.go new file mode 100644 index 000000000000..ea913144cbb7 --- /dev/null +++ b/tools/goctl/api/gogen/genhandlerstest.go @@ -0,0 +1,80 @@ +package gogen + +import ( + _ "embed" + "fmt" + "strings" + + "github.com/zeromicro/go-zero/tools/goctl/api/spec" + "github.com/zeromicro/go-zero/tools/goctl/config" + "github.com/zeromicro/go-zero/tools/goctl/util" + "github.com/zeromicro/go-zero/tools/goctl/util/format" + "github.com/zeromicro/go-zero/tools/goctl/util/pathx" +) + +//go:embed handler_test.tpl +var handlerTestTemplate string + +func genHandlerTest(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error { + handler := getHandlerName(route) + handlerPath := getHandlerFolderPath(group, route) + pkgName := handlerPath[strings.LastIndex(handlerPath, "/")+1:] + logicName := defaultLogicPackage + if handlerPath != handlerDir { + handler = strings.Title(handler) + logicName = pkgName + } + filename, err := format.FileNamingFormat(cfg.NamingFormat, handler) + if err != nil { + return err + } + + return genFile(fileGenConfig{ + dir: dir, + subdir: getHandlerFolderPath(group, route), + filename: filename + "_test.go", + templateName: "handlerTestTemplate", + category: category, + templateFile: handlerTestTemplateFile, + builtinTemplate: handlerTestTemplate, + data: map[string]any{ + "PkgName": pkgName, + "ImportPackages": genHandlerTestImports(group, route, rootPkg), + "HandlerName": handler, + "RequestType": util.Title(route.RequestTypeName()), + "ResponseType": util.Title(route.ResponseTypeName()), + "LogicName": logicName, + "LogicType": strings.Title(getLogicName(route)), + "Call": strings.Title(strings.TrimSuffix(handler, "Handler")), + "HasResp": len(route.ResponseTypeName()) > 0, + "HasRequest": len(route.RequestTypeName()) > 0, + "HasDoc": len(route.JoinedDoc()) > 0, + "Doc": getDoc(route.JoinedDoc()), + }, + }) +} + +func genHandlersTest(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error { + for _, group := range api.Service.Groups { + for _, route := range group.Routes { + if err := genHandlerTest(dir, rootPkg, cfg, group, route); err != nil { + return err + } + } + } + + return nil +} + +func genHandlerTestImports(group spec.Group, route spec.Route, parentPkg string) string { + imports := []string{ + //fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, getLogicFolderPath(group, route))), + fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir)), + fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, configDir)), + } + if len(route.RequestTypeName()) > 0 { + imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir))) + } + + return strings.Join(imports, "\n\t") +} diff --git a/tools/goctl/api/gogen/genlogictest.go b/tools/goctl/api/gogen/genlogictest.go new file mode 100644 index 000000000000..14f4ac21bd12 --- /dev/null +++ b/tools/goctl/api/gogen/genlogictest.go @@ -0,0 +1,90 @@ +package gogen + +import ( + _ "embed" + "fmt" + "strings" + + "github.com/zeromicro/go-zero/tools/goctl/api/spec" + "github.com/zeromicro/go-zero/tools/goctl/config" + "github.com/zeromicro/go-zero/tools/goctl/util/format" + "github.com/zeromicro/go-zero/tools/goctl/util/pathx" +) + +//go:embed logic_test.tpl +var logicTestTemplate string + +func genLogicTest(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error { + for _, g := range api.Service.Groups { + for _, r := range g.Routes { + err := genLogicTestByRoute(dir, rootPkg, cfg, g, r) + if err != nil { + return err + } + } + } + return nil +} + +func genLogicTestByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error { + logic := getLogicName(route) + goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic) + if err != nil { + return err + } + + imports := genLogicTestImports(route, rootPkg) + var responseString string + var returnString string + var requestString string + var requestType string + if len(route.ResponseTypeName()) > 0 { + resp := responseGoTypeName(route, typesPacket) + responseString = "(resp " + resp + ", err error)" + returnString = "return" + } else { + responseString = "error" + returnString = "return nil" + } + if len(route.RequestTypeName()) > 0 { + requestString = "req *" + requestGoTypeName(route, typesPacket) + requestType = requestGoTypeName(route, typesPacket) + } + + subDir := getLogicFolderPath(group, route) + return genFile(fileGenConfig{ + dir: dir, + subdir: subDir, + filename: goFile + "_test.go", + templateName: "logicTestTemplate", + category: category, + templateFile: logicTestTemplateFile, + builtinTemplate: logicTestTemplate, + data: map[string]any{ + "pkgName": subDir[strings.LastIndex(subDir, "/")+1:], + "imports": imports, + "logic": strings.Title(logic), + "function": strings.Title(strings.TrimSuffix(logic, "Logic")), + "responseType": responseString, + "returnString": returnString, + "request": requestString, + "hasRequest": len(requestType) > 0, + "hasResponse": len(route.ResponseTypeName()) > 0, + "requestType": requestType, + "hasDoc": len(route.JoinedDoc()) > 0, + "doc": getDoc(route.JoinedDoc()), + }, + }) +} + +func genLogicTestImports(route spec.Route, parentPkg string) string { + var imports []string + //imports = append(imports, `"context"`+"\n") + imports = append(imports, fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir))) + imports = append(imports, fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, configDir))) + if shallImportTypesPackage(route) { + imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir))) + } + //imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL)) + return strings.Join(imports, "\n\t") +} diff --git a/tools/goctl/api/gogen/handler_test.tpl b/tools/goctl/api/gogen/handler_test.tpl new file mode 100644 index 000000000000..f461f4ddda65 --- /dev/null +++ b/tools/goctl/api/gogen/handler_test.tpl @@ -0,0 +1,81 @@ +package {{.PkgName}} + +import ( + "bytes" + {{if .HasRequest}}"encoding/json"{{end}} + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + {{.ImportPackages}} +) + +{{if .HasDoc}}{{.Doc}}{{end}} +func Test{{.HandlerName}}(t *testing.T) { + // new service context + c := config.Config{} + svcCtx := svc.NewServiceContext(c) + // init mock service context here + + tests := []struct { + name string + reqBody interface{} + wantStatus int + wantResp string + setupMocks func() + }{ + { + name: "invalid request body", + reqBody: "invalid", + wantStatus: http.StatusBadRequest, + wantResp: "unsupported type", // Adjust based on actual error response + setupMocks: func() { + // No setup needed for this test case + }, + }, + { + name: "handler error", + {{if .HasRequest}}reqBody: types.{{.RequestType}}{ + //TODO: add fields here + }, + {{end}}wantStatus: http.StatusBadRequest, + wantResp: "error", // Adjust based on actual error response + setupMocks: func() { + // Mock login logic to return an error + }, + }, + { + name: "handler successful", + {{if .HasRequest}}reqBody: types.{{.RequestType}}{ + //TODO: add fields here + }, + {{end}}wantStatus: http.StatusOK, + wantResp: `{"code":0,"msg":"success","data":{}}`, // Adjust based on actual success response + setupMocks: func() { + // Mock login logic to return success + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMocks() + var reqBody []byte + {{if .HasRequest}}var err error + reqBody, err = json.Marshal(tt.reqBody) + require.NoError(t, err){{end}} + req, err := http.NewRequest("POST", "/ut", bytes.NewBuffer(reqBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + rr := httptest.NewRecorder() + handler := {{.HandlerName}}(svcCtx) + handler.ServeHTTP(rr, req) + t.Log(rr.Body.String()) + assert.Equal(t, tt.wantStatus, rr.Code) + assert.Contains(t, rr.Body.String(), tt.wantResp) + }) + } +} \ No newline at end of file diff --git a/tools/goctl/api/gogen/logic_test.tpl b/tools/goctl/api/gogen/logic_test.tpl new file mode 100644 index 000000000000..3525be56fe62 --- /dev/null +++ b/tools/goctl/api/gogen/logic_test.tpl @@ -0,0 +1,69 @@ +package {{.pkgName}} + +import ( + "context" + "testing" + + {{.imports}} + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test{{.logic}}_{{.function}}(t *testing.T) { + c := config.Config{} + mockSvcCtx := svc.NewServiceContext(c) + // init mock service context here + + tests := []struct { + name string + ctx context.Context + setupMocks func() + {{if .hasRequest}}req *{{.requestType}}{{end}} + wantErr bool + checkResp func{{if .hasResponse}}{{.responseType}}{{else}}(err error){{end}} + }{ + { + name: "response error", + ctx: context.Background(), + setupMocks: func() { + // mock data for this test case + }, + {{if .hasRequest}}req: &{{.requestType}}{ + // TODO: init your request here + },{{end}} + wantErr: true, + checkResp: func{{if .hasResponse}}{{.responseType}}{{else}}(err error){{end}} { + // TODO: Add your check logic here + }, + }, + { + name: "successful", + ctx: context.Background(), + setupMocks: func() { + // Mock data for this test case + }, + {{if .hasRequest}}req: &{{.requestType}}{ + // TODO: init your request here + },{{end}} + wantErr: false, + checkResp: func{{if .hasResponse}}{{.responseType}}{{else}}(err error){{end}} { + // TODO: Add your check logic here + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMocks() + l := New{{.logic}}(tt.ctx, mockSvcCtx) + {{if .hasResponse}}resp, {{end}}err := l.{{.function}}({{if .hasRequest}}tt.req{{end}}) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + {{if .hasResponse}}assert.NotNil(t, resp){{end}} + } + tt.checkResp({{if .hasResponse}}resp, {{end}}err) + }) + } +} \ No newline at end of file diff --git a/tools/goctl/api/gogen/template.go b/tools/goctl/api/gogen/template.go index 580fef4f8b3f..1bb05dc5a4ed 100644 --- a/tools/goctl/api/gogen/template.go +++ b/tools/goctl/api/gogen/template.go @@ -12,7 +12,9 @@ const ( contextTemplateFile = "context.tpl" etcTemplateFile = "etc.tpl" handlerTemplateFile = "handler.tpl" + handlerTestTemplateFile = "handler_test.tpl" logicTemplateFile = "logic.tpl" + logicTestTemplateFile = "logic_test.tpl" mainTemplateFile = "main.tpl" middlewareImplementCodeFile = "middleware.tpl" routesTemplateFile = "routes.tpl" @@ -25,7 +27,9 @@ var templates = map[string]string{ contextTemplateFile: contextTemplate, etcTemplateFile: etcTemplate, handlerTemplateFile: handlerTemplate, + handlerTestTemplateFile: handlerTestTemplate, logicTemplateFile: logicTemplate, + logicTestTemplateFile: logicTestTemplate, mainTemplateFile: mainTemplate, middlewareImplementCodeFile: middlewareImplementCode, routesTemplateFile: routesTemplate, diff --git a/tools/goctl/api/new/newservice.go b/tools/goctl/api/new/newservice.go index 9241d8559afd..7780a240193f 100644 --- a/tools/goctl/api/new/newservice.go +++ b/tools/goctl/api/new/newservice.go @@ -83,6 +83,6 @@ func CreateServiceCommand(_ *cobra.Command, args []string) error { return err } - err = gogen.DoGenProject(apiFilePath, abs, VarStringStyle) + err = gogen.DoGenProject(apiFilePath, abs, VarStringStyle, false) return err } diff --git a/tools/goctl/internal/flags/default_en.json b/tools/goctl/internal/flags/default_en.json index d26ec292c5d8..ae5b486c387b 100644 --- a/tools/goctl/internal/flags/default_en.json +++ b/tools/goctl/internal/flags/default_en.json @@ -37,7 +37,8 @@ "home": "{{.global.home}}", "remote": "{{.global.remote}}", "branch": "{{.global.branch}}", - "style": "{{.global.style}}" + "style": "{{.global.style}}", + "test": "Generate test files" }, "new": { "short": "Fast create api service",