From ea95b56dc7c66263482cd3001f277ed1dfdbbb8f Mon Sep 17 00:00:00 2001 From: tonywang10101 Date: Tue, 17 Oct 2023 20:08:15 +0000 Subject: [PATCH] feat(model): Add github model files copy support for LLM models --- pkg/utils/utils.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index cfd9335d..39ebd0e6 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -83,7 +83,8 @@ func findModelFiles(dir string) []string { _ = filepath.Walk(dir, func(path string, f os.FileInfo, err error) error { if strings.HasSuffix(f.Name(), ".onnx") || strings.HasSuffix(f.Name(), ".pt") || strings.HasSuffix(f.Name(), ".bias") || strings.HasSuffix(f.Name(), ".weight") || strings.HasSuffix(f.Name(), ".ini") || strings.HasSuffix(f.Name(), ".bin") || - strings.HasPrefix(f.Name(), "onnx__") { + strings.HasPrefix(f.Name(), "onnx__") || strings.HasSuffix(f.Name(), ".model") || strings.HasSuffix(f.Name(), ".json") || + strings.HasSuffix(f.Name(), ".safetensors") { modelPaths = append(modelPaths, path) } return nil @@ -132,7 +133,6 @@ type CacheModel struct { // GitHubClone clones a repository from GitHub. func GitHubClone(dir string, instanceConfig datamodel.GitHubModelConfiguration, isWithLargeFile bool) error { urlRepo := instanceConfig.Repository - // Check in the cache first. 
var cacheModels []CacheModel if config.Config.Cache.Model { @@ -185,7 +185,6 @@ func GitHubClone(dir string, instanceConfig datamodel.GitHubModelConfiguration, } defer f.Close() } - if isWithLargeFile { dvcPaths := findDVCPaths(dir) for _, dvcPath := range dvcPaths { @@ -252,12 +251,18 @@ func CopyModelFileToModelRepository(modelRepository string, dir string, tritonMo return err } // TODO: add general function to check if backend use fastertransformer, which has different model file structure - } else if modelSubNames[len(modelSubNames)-3] == "fastertransformer" && tritonSubNames[len(tritonSubNames)-2] == modelSubNames[len(modelSubNames)-3] { + } else if (modelSubNames[len(modelSubNames)-3] == "fastertransformer" || + modelSubNames[len(modelSubNames)-3] == "llama2_7b" || + modelSubNames[len(modelSubNames)-3] == "llamacode" || + modelSubNames[len(modelSubNames)-3] == "llama2_13b_chat" || + modelSubNames[len(modelSubNames)-3] == "mistral_7b" || + modelSubNames[len(modelSubNames)-3] == "mbt_7b") && + tritonSubNames[len(tritonSubNames)-2] == modelSubNames[len(modelSubNames)-3] { targetPath := fmt.Sprintf("%s/%s/%s/%s/", modelRepository, tritonModelName, modelSubNames[len(modelSubNames)-2], modelSubNames[len(modelSubNames)-1]) if err := os.MkdirAll(targetPath, os.ModePerm); err != nil { return err } - cmd := exec.Command("/bin/sh", "-c", fmt.Sprintf("cp %s %s/", modelPath, targetPath)) + cmd := exec.Command("/bin/sh", "-c", fmt.Sprintf("cp %s %s", modelPath, targetPath)) if err := cmd.Run(); err != nil { return err }