From 7afa4e32c839bc729a76f8d4d68fbfa0467cdd54 Mon Sep 17 00:00:00 2001 From: linhanyuan Date: Tue, 24 Oct 2023 14:40:36 +0800 Subject: [PATCH] feat: support preserve-structs for yaml-config of trimmer tool --- generator/golang/backend.go | 11 +------- tool/trimmer/main.go | 24 ++++------------ tool/trimmer/trim/trimmer.go | 33 ++++++++++++++++++++-- tool/trimmer/trim/trimmer_test.go | 47 +++++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 30 deletions(-) diff --git a/generator/golang/backend.go b/generator/golang/backend.go index f6e3a328..e653c1db 100644 --- a/generator/golang/backend.go +++ b/generator/golang/backend.go @@ -17,7 +17,6 @@ package golang import ( "fmt" "go/format" - "os" "path/filepath" "strings" "text/template" @@ -87,15 +86,7 @@ func (g *GoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.R g.log = log g.prepareUtilities() if g.utils.Features().TrimIDL { - wd, _ := os.Getwd() - cfg := trim.ParseYamlConfig(wd) - var err error - if cfg == nil { - err = trim.TrimAST(req.AST, nil, false, nil) - } else { - err = trim.TrimAST(req.AST, cfg.Methods, !*cfg.Preserve, cfg.PreservedStructs) - } - + err := trim.TrimAST(&trim.TrimASTArg{Ast: req.AST, TrimMethods: nil, Preserve: nil}) if err != nil { g.log.Warn("trim error:", err.Error()) } diff --git a/tool/trimmer/main.go b/tool/trimmer/main.go index b79ef37a..fa97afac 100644 --- a/tool/trimmer/main.go +++ b/tool/trimmer/main.go @@ -53,13 +53,14 @@ func main() { os.Exit(0) } - preserve := true + var preserveInput *bool if a.Preserve != "" { - preserve, err = strconv.ParseBool(a.Preserve) + preserve, err := strconv.ParseBool(a.Preserve) if err != nil { help() os.Exit(2) } + preserveInput = &preserve } // parse file to ast @@ -73,23 +74,10 @@ func main() { check(err) check(semantic.ResolveSymbols(ast)) - // try parse yaml config - var preservedStructs []string - if wd, err := os.Getwd(); err == nil { - cfg := trim.ParseYamlConfig(wd) - if cfg != nil { - if len(a.Methods) == 0 && len(cfg.Methods) > 0 { - a.Methods = cfg.Methods - } - if a.Preserve == "" && !(*cfg.Preserve) { - preserve = false - } - preservedStructs = cfg.PreservedStructs - } - } - // trim ast - check(trim.TrimAST(ast, a.Methods, !preserve, preservedStructs)) + check(trim.TrimAST(&trim.TrimASTArg{ + Ast: ast, TrimMethods: a.Methods, Preserve: preserveInput, + })) // dump the trimmed ast to idl idl, err := dump.DumpIDL(ast) diff --git a/tool/trimmer/trim/trimmer.go b/tool/trimmer/trim/trimmer.go index 1bee9d7c..97fe1bdd 100644 --- a/tool/trimmer/trim/trimmer.go +++ b/tool/trimmer/trim/trimmer.go @@ -39,8 +39,37 @@ type Trimmer struct { preservedStructs []string } -// TrimAST trim the single AST, pass method names if -m specified -func TrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, preservedStructs []string) error { +type TrimASTArg struct { + Ast *parser.Thrift + TrimMethods []string + Preserve *bool +} + +// TrimAST parse the cfg and trim the single AST +func TrimAST(arg *TrimASTArg) error { + var preservedStructs []string + if wd, err := os.Getwd(); err == nil { + cfg := ParseYamlConfig(wd) + if cfg != nil { + if len(arg.TrimMethods) == 0 && len(cfg.Methods) > 0 { + arg.TrimMethods = cfg.Methods + } + if arg.Preserve == nil && !(*cfg.Preserve) { + preserve := false + arg.Preserve = &preserve + } + preservedStructs = cfg.PreservedStructs + } + } + forceTrim := false + if arg.Preserve != nil { + forceTrim = !*arg.Preserve + } + return doTrimAST(arg.Ast, arg.TrimMethods, forceTrim, preservedStructs) +} + +// doTrimAST trim the single AST, pass method names if -m specified +func doTrimAST(ast *parser.Thrift, trimMethods []string, forceTrimming bool, preservedStructs []string) error { trimmer, err := newTrimmer(nil, "") if err != nil { return err diff --git a/tool/trimmer/trim/trimmer_test.go b/tool/trimmer/trim/trimmer_test.go index 5687a2d1..0dd219ae 100644 --- a/tool/trimmer/trim/trimmer_test.go +++ b/tool/trimmer/trim/trimmer_test.go @@ -85,3 +85,50 @@ func TestInclude(t *testing.T) { test.Assert(t, len(ast.Includes) == 1) test.Assert(t, ast.Includes[0].Used == nil) } + +func TestTrimMethod(t *testing.T) { + filename := filepath.Join("..", "test_cases", "tests", "dir", "dir2", "test.thrift") + ast, err := parser.ParseFile(filename, nil, true) + check(err) + if path := parser.CircleDetect(ast); len(path) > 0 { + check(fmt.Errorf("found include circle:\n\t%s", path)) + } + checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) + _, err = checker.CheckAll(ast) + check(err) + check(semantic.ResolveSymbols(ast)) + + methods := make([]string, 1) + methods[0] = "func1" + + err = TrimAST(&TrimASTArg{ + Ast: ast, + TrimMethods: methods, + Preserve: nil, + }) + check(err) + test.Assert(t, len(ast.Services[0].Functions) == 1) +} + +func TestPreserve(t *testing.T) { + filename := filepath.Join("..", "test_cases", "tests", "dir", "dir2", "test.thrift") + ast, err := parser.ParseFile(filename, nil, true) + check(err) + if path := parser.CircleDetect(ast); len(path) > 0 { + check(fmt.Errorf("found include circle:\n\t%s", path)) + } + checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) + _, err = checker.CheckAll(ast) + check(err) + check(semantic.ResolveSymbols(ast)) + + preserve := false + + err = TrimAST(&TrimASTArg{ + Ast: ast, + TrimMethods: nil, + Preserve: &preserve, + }) + check(err) + test.Assert(t, len(ast.Structs) == 0) +}