From cd2472eeb3c5d4056b037f26341287ba1c00a484 Mon Sep 17 00:00:00 2001 From: Deng Ming Date: Tue, 16 Jul 2024 19:55:33 +0800 Subject: [PATCH] =?UTF-8?q?AI=20=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 4 + go.mod | 4 +- internal/ai/handlers.go | 68 ++++ internal/ai/internal/domain/gpt.go | 86 ++-- .../ai/internal/integration/module_test.go | 370 ++++++++++++------ .../ai/internal/integration/startup/wire.go | 65 ++- .../internal/integration/startup/wire_gen.go | 61 ++- internal/ai/internal/repository/config.go | 48 +++ internal/ai/internal/repository/credit.go | 38 ++ internal/ai/internal/repository/dao/config.go | 54 +++ internal/ai/internal/repository/dao/credit.go | 76 +--- internal/ai/internal/repository/dao/init.go | 5 +- internal/ai/internal/repository/dao/record.go | 75 ++++ internal/ai/internal/repository/log.go | 62 +++ internal/ai/internal/repository/repository.go | 75 ---- internal/ai/internal/service/gpt.go | 31 -- internal/ai/internal/service/gpt/gpt.go | 28 ++ .../gpt/handler/biz/composition_biz.go | 55 +++ .../service/gpt/handler/biz/facade.go | 33 ++ .../gpt/handler/biz/question_examine.go | 39 ++ .../internal/service/gpt/handler/biz/type.go} | 20 +- .../service/gpt/handler/config/builder.go | 48 +++ .../service/gpt/handler/credit/builder.go | 103 +++++ .../service/gpt/handler/gpt/zhipu/handler.go | 71 ++++ .../service/gpt/handler/log/builder.go | 45 +++ .../service/gpt/handler/mocks/handler.mock.go | 141 +++++++ .../service/gpt/handler/record/builder.go | 55 +++ .../ai/internal/service/gpt/handler/type.go | 22 ++ .../internal/service/handler/biz/handler.go | 34 -- .../service/handler/config/handler.go | 49 --- .../service/handler/credit/handler.go | 119 ------ internal/ai/internal/service/handler/error.go | 5 - .../service/handler/gpt/getter/polling.go | 25 -- .../handler/gpt/getter/polling_test.go | 65 --- .../service/handler/gpt/getter/type.go | 7 - .../internal/service/handler/gpt/handler.go | 68 ---- .../service/handler/gpt/mocks/gpt.mock.go | 80 ---- .../internal/service/handler/gpt/sdk/type.go | 11 - .../service/handler/gpt/sdk/zhipu/client.go | 31 -- .../handler/gpt/sdk/zhipu/client_test.go | 45 --- .../service/handler/gpt/sdk/zhipu/gpt.go | 42 -- .../internal/service/handler/log/handler.go | 40 -- .../service/handler/response/handler.go | 52 --- .../service/handler/simple/handler.go | 23 -- .../ai/internal/service/handler/simple/ioc.go | 27 -- internal/ai/internal/service/handler/type.go | 14 - internal/ai/mocks/gpt.mock.go | 43 +- internal/ai/type.go | 4 +- internal/ai/wire.go | 72 ++-- internal/ai/wire_gen.go | 69 ++-- internal/question/internal/domain/examine.go | 2 +- .../integration/admin_handler_test.go | 4 +- .../integration/admin_set_handler_test.go | 5 +- .../integration/examine_handler_test.go | 15 +- .../internal/integration/handler_test.go | 5 +- .../integration/knowledge_job_starter_test.go | 4 +- .../internal/integration/set_handler_test.go | 4 +- .../internal/integration/startup/wire.go | 4 + .../internal/integration/startup/wire_gen.go | 6 +- .../internal/repository/dao/examine_types.go | 2 +- internal/question/internal/service/examine.go | 3 +- internal/question/internal/web/examine_vo.go | 2 +- internal/question/wire.go | 4 + internal/question/wire_gen.go | 6 +- ioc/wire.go | 2 + ioc/wire_gen.go | 15 +- 66 files changed, 1568 insertions(+), 1222 deletions(-) create mode 100644 internal/ai/handlers.go create mode 100644 internal/ai/internal/repository/config.go create mode 100644 internal/ai/internal/repository/credit.go create mode 100644 internal/ai/internal/repository/dao/config.go create mode 100644 internal/ai/internal/repository/dao/record.go create mode 100644 internal/ai/internal/repository/log.go delete mode 100644 internal/ai/internal/repository/repository.go delete mode 100644 internal/ai/internal/service/gpt.go create mode 100644 internal/ai/internal/service/gpt/gpt.go create mode 100644 internal/ai/internal/service/gpt/handler/biz/composition_biz.go create mode 100644 internal/ai/internal/service/gpt/handler/biz/facade.go create mode 100644 internal/ai/internal/service/gpt/handler/biz/question_examine.go rename internal/{question/internal/service/mocks_ai.go => ai/internal/service/gpt/handler/biz/type.go} (67%) create mode 100644 internal/ai/internal/service/gpt/handler/config/builder.go create mode 100644 internal/ai/internal/service/gpt/handler/credit/builder.go create mode 100644 internal/ai/internal/service/gpt/handler/gpt/zhipu/handler.go create mode 100644 internal/ai/internal/service/gpt/handler/log/builder.go create mode 100644 internal/ai/internal/service/gpt/handler/mocks/handler.mock.go create mode 100644 internal/ai/internal/service/gpt/handler/record/builder.go create mode 100644 internal/ai/internal/service/gpt/handler/type.go delete mode 100644 internal/ai/internal/service/handler/biz/handler.go delete mode 100644 internal/ai/internal/service/handler/config/handler.go delete mode 100644 internal/ai/internal/service/handler/credit/handler.go delete mode 100644 internal/ai/internal/service/handler/error.go delete mode 100644 internal/ai/internal/service/handler/gpt/getter/polling.go delete mode 100644 internal/ai/internal/service/handler/gpt/getter/polling_test.go delete mode 100644 internal/ai/internal/service/handler/gpt/getter/type.go delete mode 100644 internal/ai/internal/service/handler/gpt/handler.go delete mode 100644 internal/ai/internal/service/handler/gpt/mocks/gpt.mock.go delete mode 100644 internal/ai/internal/service/handler/gpt/sdk/type.go delete mode 100644 internal/ai/internal/service/handler/gpt/sdk/zhipu/client.go delete mode 100644 internal/ai/internal/service/handler/gpt/sdk/zhipu/client_test.go delete mode 100644 internal/ai/internal/service/handler/gpt/sdk/zhipu/gpt.go delete mode 100644 internal/ai/internal/service/handler/log/handler.go delete mode 100644 internal/ai/internal/service/handler/response/handler.go delete mode 100644 internal/ai/internal/service/handler/simple/handler.go delete mode 100644 internal/ai/internal/service/handler/simple/ioc.go delete mode 100644 internal/ai/internal/service/handler/type.go diff --git a/config/config.yaml b/config/config.yaml index 78f66a65..d14cafc0 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -22,6 +22,10 @@ qywechat: chatRobot: webhookURL: "your/webhookURL" +zhipu: + apikey: '' + knowledgeId: '' + mysql: dsn: "webook:webook@tcp(mysql8:3306)/webook?charset=utf8mb4&collation=utf8mb4_general_ci&parseTime=True&loc=Local&timeout=1s&readTimeout=3s&writeTimeout=3s" diff --git a/go.mod b/go.mod index b3612f08..7dc36c02 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/ecodeclub/webook -go 1.22.2 - -toolchain go1.22.5 +go 1.22.5 require ( github.com/ecodeclub/ecache v0.0.0-20240111145855-75679834beca diff --git a/internal/ai/handlers.go b/internal/ai/handlers.go new file mode 100644 index 00000000..5ab5d92b --- /dev/null +++ b/internal/ai/handlers.go @@ -0,0 +1,68 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ai + +import ( + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/biz" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/config" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/credit" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/gpt/zhipu" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/log" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/record" + "github.com/gotomicro/ego/core/econf" +) + +func InitHandlerFacade(common []handler.Builder, zhipu *zhipu.Handler) *biz.FacadeHandler { + que := InitQuestionExamineHandler(common, zhipu) + return biz.NewHandler(map[string]handler.Handler{ + que.Biz(): que, + }) +} + +func InitZhipu() *zhipu.Handler { + type Config struct { + APIKey string `yaml:"apikey"` + Price float64 `yaml:"price"` + } + var cfg Config + err := econf.UnmarshalKey("zhipu", &cfg) + if err != nil { + panic(err) + } + h, err := zhipu.NewHandler(cfg.APIKey, cfg.Price) + if err != nil { + panic(err) + } + return h +} + +func InitQuestionExamineHandler( + common []handler.Builder, + // gpt 就是真正的出口 + gpt handler.Handler) *biz.CompositionHandler { + // log -> cfg -> credit -> record -> question_examine -> gpt + builder := biz.NewQuestionExamineBizHandlerBuilder() + common = append(common, builder) + res := biz.NewCombinedBizHandler("question_examine", common, gpt) + return res +} + +func InitCommonHandlers(log *log.HandlerBuilder, + cfg *config.HandlerBuilder, + credit *credit.HandlerBuilder, + record *record.HandlerBuilder) []handler.Builder { + return []handler.Builder{log, cfg, credit, record} +} diff --git a/internal/ai/internal/domain/gpt.go b/internal/ai/internal/domain/gpt.go index 6b161b09..08a2266d 100644 --- a/internal/ai/internal/domain/gpt.go +++ b/internal/ai/internal/domain/gpt.go @@ -1,72 +1,90 @@ package domain +const BizQuestionExamine = "question_examine" + type GPTRequest struct { Biz string Uid int64 // 请求id Tid string // 用户的输入 - Input []string - BizConfig GPTBiz + Input []string + // Prompt 将 input 和 PromptTemplate 结合之后生成的正儿八经的 Prompt + Prompt string + // 业务相关的配置 + Config BizConfig } type GPTResponse struct { // 花费的token - Tokens int + Tokens int64 // 花费的金额 Amount int64 // gpt的回答 Answer string } -type GPTBiz struct { - // 业务名称 - Biz string - // 每个token的钱 分为单位 - AmountPerToken float64 - // 每个token的积分 - CreditPerToken float64 - // 一次最多返回多少Tokens - MaxTokensPerTime int +type BizConfig struct { + // 允许的最长输入 + // 这里我们不用计算 token,只需要简单约束一下字符串长度就可以 + MaxInput int + // 使用的知识库 + KnowledgeId string + // 提示词。虽然这里只有一个 PromptTemplate 字段, + // 但是在部分业务里面,它是一个 json + // 这里一般使用 %s + // 后续考虑 key value 的形式 + PromptTemplate string } -type GPTCreditLog struct { +type GPTCredit struct { Id int64 Tid string Uid int64 Biz string Tokens int64 Amount int64 - Credit int64 - Status GPTLogStatus - Prompt string - Answer string + Status CreditStatus Ctime int64 Utime int64 } -type GPTLog struct { - Id int64 - Tid string - Uid int64 - Biz string - Tokens int64 - Amount int64 - Status GPTLogStatus - Prompt string - Answer string - Ctime int64 - Utime int64 +type GPTRecord struct { + Id int64 + Tid string + Uid int64 + Biz string + Tokens int64 + Amount int64 + Input []string + Status RecordStatus + KnowledgeId string + PromptTemplate string + Answer string + Ctime int64 + Utime int64 +} + +type CreditStatus uint8 + +const ( + CreditStatusProcessing CreditStatus = iota + CreditStatusSuccess + CreditStatusFailed +) + +func (g CreditStatus) ToUint8() uint8 { + return uint8(g) } -type GPTLogStatus uint8 +type RecordStatus uint8 -func (g GPTLogStatus) ToUint8() uint8 { +func (g RecordStatus) ToUint8() uint8 { return uint8(g) } const ( - ProcessingStatus GPTLogStatus = 0 - SuccessStatus GPTLogStatus = 1 - FailLogStatus GPTLogStatus = 2 + RecordStatusProcessing RecordStatus = 0 + RecordStatusSuccess RecordStatus = 1 + RecordStatusFailed RecordStatus = 2 ) diff --git a/internal/ai/internal/integration/module_test.go b/internal/ai/internal/integration/module_test.go index 902446e5..3f6e3442 100644 --- a/internal/ai/internal/integration/module_test.go +++ b/internal/ai/internal/integration/module_test.go @@ -4,15 +4,18 @@ package integration import ( "context" - "database/sql" "errors" "testing" + "time" + + "github.com/ecodeclub/ekit/sqlx" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt" + gptHandler "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" + hdlmocks "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/mocks" "github.com/ecodeclub/webook/internal/ai/internal/domain" "github.com/ecodeclub/webook/internal/ai/internal/integration/startup" "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" - gptmocks "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/mocks" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/sdk" "github.com/ecodeclub/webook/internal/credit" creditmocks "github.com/ecodeclub/webook/internal/credit/mocks" testioc "github.com/ecodeclub/webook/internal/test/ioc" @@ -23,192 +26,323 @@ import ( "go.uber.org/mock/gomock" ) +const knowledgeId = "abc" + type GptSuite struct { suite.Suite - logDao dao.GPTLogDAO + logDao dao.GPTRecordDAO db *egorm.Component + svc gpt.Service } func TestGptSuite(t *testing.T) { suite.Run(t, new(GptSuite)) } -func (g *GptSuite) SetupSuite() { +func (s *GptSuite) SetupSuite() { db := testioc.InitDB() - g.db = db - g.logDao = dao.NewGPTLogDAO(db) + s.db = db + err := dao.InitTables(db) + require.NoError(s.T(), err) + s.logDao = dao.NewGORMGPTLogDAO(db) + + // 先插入 BizConfig + now := time.Now().UnixMilli() + err = s.db.Create(&dao.BizConfig{ + Biz: domain.BizQuestionExamine, + MaxInput: 100, + PromptTemplate: "这是问题 %s,这是用户输入 %s", + KnowledgeId: knowledgeId, + Ctime: now, + Utime: now, + }).Error + assert.NoError(s.T(), err) } + func (s *GptSuite) TearDownSuite() { - err := s.db.Exec("DROP TABLE `gpt_logs`").Error - require.NoError(s.T(), err) - err = s.db.Exec("DROP TABLE `gpt_credit_logs`").Error + err := s.db.Exec("TRUNCATE TABLE `ai_biz_configs`").Error require.NoError(s.T(), err) } func (s *GptSuite) TearDownTest() { - err := s.db.Exec("TRUNCATE TABLE `gpt_logs`").Error + err := s.db.Exec("TRUNCATE TABLE `gpt_records`").Error require.NoError(s.T(), err) - err = s.db.Exec("TRUNCATE TABLE `gpt_credit_logs`").Error + err = s.db.Exec("TRUNCATE TABLE `gpt_credits`").Error require.NoError(s.T(), err) } -func (g *GptSuite) TestService() { - t := g.T() - tesecases := []struct { +func (s *GptSuite) TestService() { + t := s.T() + testCases := []struct { name string req domain.GPTRequest - newSvcFunc func(t *testing.T, ctrl *gomock.Controller) credit.Service - newAiFunc func(t *testing.T, ctrl *gomock.Controller) sdk.GPTSdk + before func(t *testing.T, ctrl *gomock.Controller) (gptHandler.Handler, credit.Service) assertFunc assert.ErrorAssertionFunc after func(t *testing.T, resp domain.GPTResponse) }{ { - name: "成功访问", + name: "八股文测试-成功", req: domain.GPTRequest{ - Biz: "simple", + Biz: domain.BizQuestionExamine, Uid: 123, Tid: "11", Input: []string{ - "nihao", + "问题1", + "用户输入1", }, }, - newAiFunc: func(t *testing.T, ctrl *gomock.Controller) sdk.GPTSdk { - mockAiSdk := gptmocks.NewMockGPTSdk(ctrl) - mockAiSdk.EXPECT().Invoke(gomock.Any(), gomock.Any()).Return(100, "aians", nil) - return mockAiSdk - }, - newSvcFunc: func(t *testing.T, ctrl *gomock.Controller) credit.Service { - mockCreditSvc := creditmocks.NewMockService(ctrl) - mockCreditSvc.EXPECT().GetCreditsByUID(gomock.Any(), int64(123)).Return(credit.Credit{ - TotalAmount: 1000, - LockedTotalAmount: 0, + assertFunc: assert.NoError, + before: func(t *testing.T, + ctrl *gomock.Controller) (gptHandler.Handler, credit.Service) { + gptHdl := hdlmocks.NewMockHandler(ctrl) + gptHdl.EXPECT().Handle(gomock.Any(), gomock.Any()). + Return(domain.GPTResponse{ + Tokens: 100, + Amount: 100, + Answer: "aians", + }, nil) + creditSvc := creditmocks.NewMockService(ctrl) + creditSvc.EXPECT().GetCreditsByUID(gomock.Any(), gomock.Any()).Return(credit.Credit{ + TotalAmount: 1000, }, nil) - mockCreditSvc.EXPECT().AddCredits(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, cre credit.Credit) error { - if cre.Uid != 123 { - return errors.New("incorrect uid") - } - if len(cre.Logs) <= 0 { - return errors.New("incorrect logs") - } - l := cre.Logs[0] - if l.ChangeAmount != 100 || l.Biz != "ai-gpt" || l.BizId <= 0 || l.Desc == "" { - return errors.New("incorrect logs") - } - return nil - }) - return mockCreditSvc + creditSvc.EXPECT().AddCredits(gomock.Any(), gomock.Any()).Return(nil) + return gptHdl, creditSvc }, - assertFunc: assert.NoError, after: func(t *testing.T, resp domain.GPTResponse) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() // 校验response写入的内容是否正确 assert.Equal(t, domain.GPTResponse{ Tokens: 100, Amount: 100, Answer: "aians", }, resp) - logModel, err := g.logDao.FirstLog(context.Background(), 1) + var logModel dao.GPTRecord + err := s.db.WithContext(ctx).Where("id = ?", 1).First(&logModel).Error require.NoError(t, err) - g.assertLog(&dao.GptLog{ - Id: 1, - Tid: "11", - Uid: 123, - Biz: "simple", - Tokens: 100, - Amount: 100, - Status: 1, - Prompt: sql.NullString{ - Valid: true, - String: "[\"nihao\"]", - }, - Answer: sql.NullString{ - Valid: true, - String: "aians", + s.assertLog(dao.GPTRecord{ + Id: 1, + Tid: "11", + Uid: 123, + Biz: domain.BizQuestionExamine, + Tokens: 100, + Amount: 100, + KnowledgeId: knowledgeId, + Input: sqlx.JsonColumn[[]string]{ + Valid: true, + Val: []string{ + "问题1", + "用户输入1", + }, }, + Status: 1, + PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"), + Answer: sqlx.NewNullString("aians"), }, logModel) // 校验credit写入的内容是否正确 - creditLogModel, err := g.logDao.FirstCreditLog(context.Background(), 1) + var creditLogModel dao.GPTCredit + err = s.db.WithContext(ctx).Where("id = ?", 1).First(&creditLogModel).Error require.NoError(t, err) - g.assertCreditLog(&dao.GptCreditLog{ + s.assertCreditLog(dao.GPTCredit{ Id: 1, Tid: "11", Uid: 123, - Biz: "simple", - Tokens: 100, + Biz: domain.BizQuestionExamine, Amount: 100, - Credit: 100, Status: 1, - Prompt: sql.NullString{ - Valid: true, - String: "[\"nihao\"]", - }, - Answer: sql.NullString{ - Valid: true, - String: "aians", - }, }, creditLogModel) }, }, { - name: "积分不足扣款失败", + name: "积分不足", req: domain.GPTRequest{ - Biz: "simple", - Uid: 123, + Biz: domain.BizQuestionExamine, + Uid: 124, Tid: "11", Input: []string{ "nihao", }, }, - newAiFunc: func(t *testing.T, ctrl *gomock.Controller) sdk.GPTSdk { - mockAiSdk := gptmocks.NewMockGPTSdk(ctrl) - - return mockAiSdk - }, - newSvcFunc: func(t *testing.T, ctrl *gomock.Controller) credit.Service { - mockCreditSvc := creditmocks.NewMockService(ctrl) - mockCreditSvc.EXPECT().GetCreditsByUID(gomock.Any(), int64(123)).Return(credit.Credit{ - TotalAmount: 1, - LockedTotalAmount: 0, + before: func(t *testing.T, + ctrl *gomock.Controller) (gptHandler.Handler, credit.Service) { + gptHdl := hdlmocks.NewMockHandler(ctrl) + creditSvc := creditmocks.NewMockService(ctrl) + creditSvc.EXPECT().GetCreditsByUID(gomock.Any(), gomock.Any()).Return(credit.Credit{ + TotalAmount: 0, }, nil) - - return mockCreditSvc + return gptHdl, creditSvc + }, + after: func(t *testing.T, resp domain.GPTResponse) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + var logModel dao.GPTRecord + err := s.db.WithContext(ctx).Where("uid = ?", 124).First(&logModel).Error + require.NoError(t, err) + s.assertLog(dao.GPTRecord{ + Id: 1, + Tid: "11", + Uid: 124, + Biz: domain.BizQuestionExamine, + KnowledgeId: knowledgeId, + Input: sqlx.JsonColumn[[]string]{ + Valid: true, + Val: []string{ + "问题1", + "用户输入1", + }, + }, + Status: domain.RecordStatusFailed.ToUint8(), + PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"), + }, logModel) }, assertFunc: assert.Error, }, { - name: "creditSvc调用失败", + name: "GPT调用失败", req: domain.GPTRequest{ - Biz: "simple", - Uid: 123, + Biz: domain.BizQuestionExamine, + Uid: 125, Tid: "11", Input: []string{ - "nihao", + "问题1", + "用户输入1", }, }, - newAiFunc: func(t *testing.T, ctrl *gomock.Controller) sdk.GPTSdk { - mockAiSdk := gptmocks.NewMockGPTSdk(ctrl) - mockAiSdk.EXPECT().Invoke(gomock.Any(), gomock.Any()).Return(100, "aians", nil) - return mockAiSdk - }, - newSvcFunc: func(t *testing.T, ctrl *gomock.Controller) credit.Service { - mockCreditSvc := creditmocks.NewMockService(ctrl) - mockCreditSvc.EXPECT().GetCreditsByUID(gomock.Any(), int64(123)).Return(credit.Credit{ - TotalAmount: 1000, - LockedTotalAmount: 0, + before: func(t *testing.T, + ctrl *gomock.Controller) (gptHandler.Handler, credit.Service) { + gptHdl := hdlmocks.NewMockHandler(ctrl) + gptHdl.EXPECT().Handle(gomock.Any(), gomock.Any()). + Return(domain.GPTResponse{}, errors.New("调用失败")) + creditSvc := creditmocks.NewMockService(ctrl) + creditSvc.EXPECT().GetCreditsByUID(gomock.Any(), gomock.Any()).Return(credit.Credit{ + TotalAmount: 1000, }, nil) - mockCreditSvc.EXPECT().AddCredits(gomock.Any(), gomock.Any()).Return(errors.New("服务内部错误")) - return mockCreditSvc + return gptHdl, creditSvc + }, + after: func(t *testing.T, resp domain.GPTResponse) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + var logModel dao.GPTRecord + err := s.db.WithContext(ctx).Where("uid = ?", 125).First(&logModel).Error + require.NoError(t, err) + s.assertLog(dao.GPTRecord{ + Id: 1, + Tid: "11", + Uid: 125, + Biz: domain.BizQuestionExamine, + Tokens: 100, + Amount: 100, + KnowledgeId: knowledgeId, + Input: sqlx.JsonColumn[[]string]{ + Valid: true, + Val: []string{ + "问题1", + "用户输入1", + }, + }, + Status: domain.CreditStatusFailed.ToUint8(), + PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"), + Answer: sqlx.NewNullString("aians"), + }, logModel) + // 校验credit写入的内容是否正确 + var creditLogModel dao.GPTCredit + err = s.db.WithContext(ctx).Where("id = ?", 1).First(&creditLogModel).Error + require.NoError(t, err) + s.assertCreditLog(dao.GPTCredit{ + Id: 1, + Tid: "11", + Uid: 125, + Biz: domain.BizQuestionExamine, + Amount: 100, + Status: domain.RecordStatusFailed.ToUint8(), + }, creditLogModel) + }, + assertFunc: assert.Error, + }, + { + name: "积分足够,扣款失败", + req: domain.GPTRequest{ + Biz: domain.BizQuestionExamine, + Uid: 126, + Tid: "11", + Input: []string{ + "问题1", + "用户输入1", + }, }, assertFunc: assert.Error, + before: func(t *testing.T, + ctrl *gomock.Controller) (gptHandler.Handler, credit.Service) { + gptHdl := hdlmocks.NewMockHandler(ctrl) + gptHdl.EXPECT().Handle(gomock.Any(), gomock.Any()). + Return(domain.GPTResponse{ + Tokens: 100, + Amount: 100, + Answer: "aians", + }, nil) + creditSvc := creditmocks.NewMockService(ctrl) + creditSvc.EXPECT().GetCreditsByUID(gomock.Any(), gomock.Any()).Return(credit.Credit{ + TotalAmount: 1000, + }, nil) + creditSvc.EXPECT().AddCredits(gomock.Any(), gomock.Any()).Return(errors.New("mock db error")) + return gptHdl, creditSvc + }, + after: func(t *testing.T, resp domain.GPTResponse) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + // 校验response写入的内容是否正确 + assert.Equal(t, domain.GPTResponse{ + Tokens: 100, + Amount: 100, + Answer: "aians", + }, resp) + var logModel dao.GPTRecord + err := s.db.WithContext(ctx).Where("uid = ?", 126).First(&logModel).Error + require.NoError(t, err) + s.assertLog(dao.GPTRecord{ + Id: 1, + Tid: "11", + Uid: 126, + Biz: domain.BizQuestionExamine, + Tokens: 100, + Amount: 100, + KnowledgeId: knowledgeId, + Input: sqlx.JsonColumn[[]string]{ + Valid: true, + Val: []string{ + "问题1", + "用户输入1", + }, + }, + Status: domain.RecordStatusFailed.ToUint8(), + PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"), + Answer: sqlx.NewNullString("aians"), + }, logModel) + // 校验credit写入的内容是否正确 + var creditLogModel dao.GPTCredit + err = s.db.WithContext(ctx).Where("id = ?", 1).First(&creditLogModel).Error + require.NoError(t, err) + s.assertCreditLog(dao.GPTCredit{ + Id: 1, + Tid: "11", + Uid: 126, + Biz: domain.BizQuestionExamine, + Amount: 100, + Status: domain.CreditStatusFailed.ToUint8(), + }, creditLogModel) + }, }, } - for _, tc := range tesecases { + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ctrl := gomock.NewController(t) - aiSdk := tc.newAiFunc(t, ctrl) - creditSvc := tc.newSvcFunc(t, ctrl) - mou, err := startup.InitModule(aiSdk, creditSvc) + defer ctrl.Finish() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + mockHdl, mockCredit := tc.before(t, ctrl) + mou, err := startup.InitModule(s.db, mockHdl, &credit.Module{Svc: mockCredit}) require.NoError(t, err) - resp, err := mou.Svc.Invoke(context.Background(), tc.req) + resp, err := mou.Svc.Invoke(ctx, tc.req) tc.assertFunc(t, err) if err != nil { return @@ -218,18 +352,18 @@ func (g *GptSuite) TestService() { } } -func (g *GptSuite) assertLog(wantLog *dao.GptLog, actual *dao.GptLog) { - require.True(g.T(), actual.Ctime != 0) - require.True(g.T(), actual.Utime != 0) +func (s *GptSuite) assertLog(wantLog dao.GPTRecord, actual dao.GPTRecord) { + require.True(s.T(), actual.Ctime != 0) + require.True(s.T(), actual.Utime != 0) actual.Ctime = 0 actual.Utime = 0 - assert.Equal(g.T(), wantLog, actual) + assert.Equal(s.T(), wantLog, actual) } -func (g *GptSuite) assertCreditLog(wantLog *dao.GptCreditLog, actual *dao.GptCreditLog) { - require.True(g.T(), actual.Ctime != 0) - require.True(g.T(), actual.Utime != 0) +func (s *GptSuite) assertCreditLog(wantLog dao.GPTCredit, actual dao.GPTCredit) { + require.True(s.T(), actual.Ctime != 0) + require.True(s.T(), actual.Utime != 0) actual.Ctime = 0 actual.Utime = 0 - assert.Equal(g.T(), wantLog, actual) + assert.Equal(s.T(), wantLog, actual) } diff --git a/internal/ai/internal/integration/startup/wire.go b/internal/ai/internal/integration/startup/wire.go index 32360bd7..9061d236 100644 --- a/internal/ai/internal/integration/startup/wire.go +++ b/internal/ai/internal/integration/startup/wire.go @@ -3,24 +3,71 @@ package startup import ( + "sync" + "github.com/ecodeclub/webook/internal/ai" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/biz" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/config" + aicredit "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/credit" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/log" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/record" + "github.com/ecodeclub/webook/internal/ai/internal/repository" - "github.com/ecodeclub/webook/internal/ai/internal/service" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/sdk" + "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" "github.com/ecodeclub/webook/internal/credit" - testioc "github.com/ecodeclub/webook/internal/test/ioc" + "github.com/ego-component/egorm" "github.com/google/wire" + "gorm.io/gorm" ) -func InitModule( - aisdk sdk.GPTSdk, creditSvc credit.Service) (*ai.Module, error) { +func InitModule(db *egorm.Component, + hdl handler.Handler, + creditSvc *credit.Module) (*ai.Module, error) { wire.Build( - testioc.InitDB, - ai.InitGPTDAO, + gpt.NewGPTService, repository.NewGPTLogRepo, - ai.InitHandlers, - service.NewGPTService, + repository.NewGPTCreditLogRepo, + repository.NewCachedConfigRepository, + + InitGPTCreditLogDAO, + dao.NewGORMGPTLogDAO, + dao.NewGORMConfigDAO, + + config.NewBuilder, + log.NewHandler, + record.NewHandler, + aicredit.NewHandlerBuilder, + + ai.InitCommonHandlers, + InitHandlerFacade, + wire.Struct(new(ai.Module), "*"), + wire.FieldsOf(new(*credit.Module), "Svc"), ) return new(ai.Module), nil } + +func InitHandlerFacade(common []handler.Builder, gpt handler.Handler) *biz.FacadeHandler { + que := ai.InitQuestionExamineHandler(common, gpt) + return biz.NewHandler(map[string]handler.Handler{ + que.Biz(): que, + }) +} + +var daoOnce = sync.Once{} + +func InitTableOnce(db *gorm.DB) { + daoOnce.Do(func() { + err := dao.InitTables(db) + if err != nil { + panic(err) + } + }) +} + +func InitGPTCreditLogDAO(db *egorm.Component) dao.GPTCreditDAO { + InitTableOnce(db) + return dao.NewGPTCreditLogDAO(db) +} diff --git a/internal/ai/internal/integration/startup/wire_gen.go b/internal/ai/internal/integration/startup/wire_gen.go index 339102d7..061fcf13 100644 --- a/internal/ai/internal/integration/startup/wire_gen.go +++ b/internal/ai/internal/integration/startup/wire_gen.go @@ -1,30 +1,73 @@ // Code generated by Wire. DO NOT EDIT. -//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:generate go run github.com/google/wire/cmd/wire //go:build !wireinject // +build !wireinject package startup import ( + "sync" + "github.com/ecodeclub/webook/internal/ai" "github.com/ecodeclub/webook/internal/ai/internal/repository" - service2 "github.com/ecodeclub/webook/internal/ai/internal/service" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/sdk" + "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/biz" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/config" + credit2 "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/credit" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/log" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/record" "github.com/ecodeclub/webook/internal/credit" - testioc "github.com/ecodeclub/webook/internal/test/ioc" + "github.com/ego-component/egorm" + "gorm.io/gorm" ) // Injectors from wire.go: -func InitModule(aisdk sdk.GPTSdk, creditSvc credit.Service) (*ai.Module, error) { - db := testioc.InitDB() - gptLogDAO := ai.InitGPTDAO(db) +func InitModule(db *gorm.DB, hdl handler.Handler, creditSvc *credit.Module) (*ai.Module, error) { + handlerBuilder := log.NewHandler() + configDAO := dao.NewGORMConfigDAO(db) + configRepository := repository.NewCachedConfigRepository(configDAO) + configHandlerBuilder := config.NewBuilder(configRepository) + service := creditSvc.Svc + gptCreditLogDAO := InitGPTCreditLogDAO(db) + gptCreditLogRepo := repository.NewGPTCreditLogRepo(gptCreditLogDAO) + creditHandlerBuilder := credit2.NewHandlerBuilder(service, gptCreditLogRepo) + gptLogDAO := dao.NewGORMGPTLogDAO(db) gptLogRepo := repository.NewGPTLogRepo(gptLogDAO) - v := ai.InitHandlers(gptLogRepo, aisdk, creditSvc) - gptService := service2.NewGPTService(v) + recordHandlerBuilder := record.NewHandler(gptLogRepo) + v := ai.InitCommonHandlers(handlerBuilder, configHandlerBuilder, creditHandlerBuilder, recordHandlerBuilder) + facadeHandler := InitHandlerFacade(v, hdl) + gptService := gpt.NewGPTService(facadeHandler) module := &ai.Module{ Svc: gptService, } return module, nil } + +// wire.go: + +func InitHandlerFacade(common []handler.Builder, gpt2 handler.Handler) *biz.FacadeHandler { + que := ai.InitQuestionExamineHandler(common, gpt2) + return biz.NewHandler(map[string]handler.Handler{ + que.Biz(): que, + }) +} + +var daoOnce = sync.Once{} + +func InitTableOnce(db *gorm.DB) { + daoOnce.Do(func() { + err := dao.InitTables(db) + if err != nil { + panic(err) + } + }) +} + +func InitGPTCreditLogDAO(db *egorm.Component) dao.GPTCreditDAO { + InitTableOnce(db) + return dao.NewGPTCreditLogDAO(db) +} diff --git a/internal/ai/internal/repository/config.go b/internal/ai/internal/repository/config.go new file mode 100644 index 00000000..85d8c9ea --- /dev/null +++ b/internal/ai/internal/repository/config.go @@ -0,0 +1,48 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package repository + +import ( + "context" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" +) + +type ConfigRepository interface { + GetConfig(ctx context.Context, biz string) (domain.BizConfig, error) +} + +// CachedConfigRepository 这个是一定要搞缓存的 +// 后续性能瓶颈了再说 +type CachedConfigRepository struct { + dao dao.ConfigDAO +} + +func NewCachedConfigRepository(dao dao.ConfigDAO) ConfigRepository { + return &CachedConfigRepository{dao: dao} +} + +func (repo *CachedConfigRepository) GetConfig(ctx context.Context, biz string) (domain.BizConfig, error) { + res, err := repo.dao.GetConfig(ctx, biz) + if err != nil { + return domain.BizConfig{}, err + } + return domain.BizConfig{ + MaxInput: res.MaxInput, + PromptTemplate: res.PromptTemplate, + KnowledgeId: res.KnowledgeId, + }, nil +} diff --git a/internal/ai/internal/repository/credit.go b/internal/ai/internal/repository/credit.go new file mode 100644 index 00000000..b0944708 --- /dev/null +++ b/internal/ai/internal/repository/credit.go @@ -0,0 +1,38 @@ +package repository + +import ( + "context" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" +) + +type GPTCreditLogRepo interface { + SaveCredit(ctx context.Context, GPTDeductLog domain.GPTCredit) (int64, error) +} + +type gptCreditLogRepo struct { + logDao dao.GPTCreditDAO +} + +func NewGPTCreditLogRepo(logDao dao.GPTCreditDAO) GPTCreditLogRepo { + return &gptCreditLogRepo{ + logDao: logDao, + } +} + +func (g *gptCreditLogRepo) creditLogToEntity(gptLog domain.GPTCredit) dao.GPTCredit { + return dao.GPTCredit{ + Id: gptLog.Id, + Tid: gptLog.Tid, + Uid: gptLog.Uid, + Biz: gptLog.Biz, + Amount: gptLog.Amount, + Status: gptLog.Status.ToUint8(), + } +} + +func (g *gptCreditLogRepo) SaveCredit(ctx context.Context, gptDeductLog domain.GPTCredit) (int64, error) { + logEntity := g.creditLogToEntity(gptDeductLog) + return g.logDao.SaveCredit(ctx, logEntity) +} diff --git a/internal/ai/internal/repository/dao/config.go b/internal/ai/internal/repository/dao/config.go new file mode 100644 index 00000000..2214f675 --- /dev/null +++ b/internal/ai/internal/repository/dao/config.go @@ -0,0 +1,54 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dao + +import ( + "context" + + "github.com/ego-component/egorm" +) + +type ConfigDAO interface { + GetConfig(ctx context.Context, biz string) (BizConfig, error) +} + +type GORMConfigDAO struct { + db *egorm.Component +} + +func NewGORMConfigDAO(db *egorm.Component) ConfigDAO { + return &GORMConfigDAO{db: db} +} + +func (dao *GORMConfigDAO) GetConfig(ctx context.Context, biz string) (BizConfig, error) { + var res BizConfig + err := dao.db.WithContext(ctx).Where("biz = ?", biz).First(&res).Error + return res, err +} + +type BizConfig struct { + Id int64 `gorm:"primaryKey;autoIncrement;comment:AI biz 配置表ID"` + Biz string `gorm:"type:varchar(256);uniqueIndex;not null;comment:业务类型名"` + MaxInput int `gorm:"comment:最大输入长度"` + PromptTemplate string + KnowledgeId string `gorm:"type:varchar(256);not null;comment:使用的知识库 ID"` + // 其它字段按需添加 + Ctime int64 + Utime int64 +} + +func (c BizConfig) TableName() string { + return "ai_biz_configs" +} diff --git a/internal/ai/internal/repository/dao/credit.go b/internal/ai/internal/repository/dao/credit.go index 9f84f062..1168e691 100644 --- a/internal/ai/internal/repository/dao/credit.go +++ b/internal/ai/internal/repository/dao/credit.go @@ -2,89 +2,47 @@ package dao import ( "context" - "database/sql" "time" "github.com/ego-component/egorm" "gorm.io/gorm/clause" ) -// gpt扣分调用记录表 -type GptCreditLog struct { - Id int64 `gorm:"primaryKey;autoIncrement;comment:积分流水表自增ID"` - Tid string `gorm:"type:varchar(256);not null;comment:一次请求的Tid,可能有多次"` - Uid int64 `gorm:"not null;index:idx_user_id;comment:用户ID"` - Biz string `gorm:"type:varchar(256);not null;comment:业务类型名"` - Tokens int64 `gorm:"type:int;default:0;not null;comment:扣费token数"` - Amount int64 `gorm:"type:int;default:0;not null;comment:具体扣费的换算的钱,分为单位"` - Credit int64 `gorm:"type:int;default:0;not null;comment:具体扣费的积分"` - Status uint8 `gorm:"type:tinyint unsigned;not null;default:0;comment:调用状态 0=进行中 1=成功, 2=失败"` - Prompt sql.NullString `gorm:"type:text;comment:调用请求"` - Answer sql.NullString `gorm:"type:text;comment:gpt的回答"` +// GPTCredit gpt扣分调用记录表 +type GPTCredit struct { + Id int64 `gorm:"primaryKey;autoIncrement;comment:积分流水表自增ID"` + Tid string `gorm:"type:varchar(256);not null;comment:一次请求的Tid,可能有多次"` + Uid int64 `gorm:"not null;index:idx_user_id;comment:用户ID"` + Biz string `gorm:"type:varchar(256);not null;comment:业务类型名"` + Amount int64 `gorm:"type:int;default:0;not null;comment:具体扣费的换算的钱,分为单位"` + Status uint8 `gorm:"type:tinyint unsigned;not null;default:0;comment:调用状态 0=进行中 1=成功, 2=失败"` Ctime int64 Utime int64 } -type GptLog struct { - Id int64 `gorm:"primaryKey;autoIncrement;comment:积分流水表自增ID"` - Tid string `gorm:"type:varchar(256);not null;uniqueIndex:unq_tid;comment:一次请求的Tid只能有一次"` - Uid int64 `gorm:"not null;index:idx_user_id;comment:用户ID"` - Biz string `gorm:"type:varchar(256);not null;comment:业务类型名"` - Tokens int64 `gorm:"type:int;default:0;comment:扣费token数"` - Amount int64 `gorm:"type:int;default:0;comment:具体扣费的换算的钱,分为单位"` - Status uint8 `gorm:"type:tinyint unsigned;not null;default:1;comment:调用状态 1=成功, 2=失败"` - Prompt sql.NullString `gorm:"type:text;comment:调用请求"` - Answer sql.NullString `gorm:"type:text;comment:gpt的回答"` - Ctime int64 - Utime int64 +func (l GPTCredit) TableName() string { + return "gpt_credits" } -type GPTLogDAO interface { - SaveCreditLog(ctx context.Context, GPTDeductLog GptCreditLog) (int64, error) - SaveLog(ctx context.Context, GPTLog GptLog) (int64, error) - FirstCreditLog(ctx context.Context, id int64) (*GptCreditLog, error) - FirstLog(ctx context.Context, id int64) (*GptLog, error) +type GPTCreditDAO interface { + SaveCredit(ctx context.Context, GPTDeductLog GPTCredit) (int64, error) } -type gptLogDAO struct { +type GORMGPTCreditDAO struct { db *egorm.Component } -func NewGPTLogDAO(db *egorm.Component) GPTLogDAO { - return &gptLogDAO{ +func NewGPTCreditLogDAO(db *egorm.Component) GPTCreditDAO { + return &GORMGPTCreditDAO{ db: db, } } -func (g *gptLogDAO) FirstCreditLog(ctx context.Context, id int64) (*GptCreditLog, error) { - logModel := &GptCreditLog{} - err := g.db.WithContext(ctx).Model(&GptCreditLog{}).Where("id = ?", id).First(logModel).Error - return logModel, err -} - -func (g *gptLogDAO) FirstLog(ctx context.Context, id int64) (*GptLog, error) { - logModel := &GptLog{} - err := g.db.WithContext(ctx).Model(&GptLog{}).Where("id = ?", id).First(logModel).Error - return logModel, err -} - -func (g *gptLogDAO) SaveCreditLog(ctx context.Context, gptLog GptCreditLog) (int64, error) { - now := time.Now().UnixMilli() - gptLog.Ctime = now - gptLog.Utime = now - err := g.db.WithContext(ctx).Model(&GptCreditLog{}). - Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "id"}}, - DoUpdates: clause.AssignmentColumns([]string{"status", "utime"}), - }).Create(&gptLog).Error - return gptLog.Id, err -} - -func (g *gptLogDAO) SaveLog(ctx context.Context, gptLog GptLog) (int64, error) { +func (g *GORMGPTCreditDAO) SaveCredit(ctx context.Context, gptLog GPTCredit) (int64, error) { now := time.Now().UnixMilli() gptLog.Ctime = now gptLog.Utime = now - err := g.db.WithContext(ctx).Model(&GptLog{}). + err := g.db.WithContext(ctx).Model(&GPTCredit{}). Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "id"}}, DoUpdates: clause.AssignmentColumns([]string{"status", "utime"}), diff --git a/internal/ai/internal/repository/dao/init.go b/internal/ai/internal/repository/dao/init.go index f06a80e4..f33c1526 100644 --- a/internal/ai/internal/repository/dao/init.go +++ b/internal/ai/internal/repository/dao/init.go @@ -4,7 +4,8 @@ import "github.com/ego-component/egorm" func InitTables(db *egorm.Component) error { return db.AutoMigrate( - &GptCreditLog{}, - &GptLog{}, + &GPTCredit{}, + &GPTRecord{}, + &BizConfig{}, ) } diff --git a/internal/ai/internal/repository/dao/record.go b/internal/ai/internal/repository/dao/record.go new file mode 100644 index 00000000..6781e614 --- /dev/null +++ b/internal/ai/internal/repository/dao/record.go @@ -0,0 +1,75 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dao + +import ( + "context" + "database/sql" + "time" + + "github.com/ecodeclub/ekit/sqlx" + "github.com/ego-component/egorm" + "gorm.io/gorm/clause" +) + +type GPTRecordDAO interface { + Save(ctx context.Context, r GPTRecord) (int64, error) +} + +type GORMGPTLogDAO struct { + db *egorm.Component +} + +func NewGORMGPTLogDAO(db *egorm.Component) GPTRecordDAO { + return &GORMGPTLogDAO{db: db} +} + +func (g *GORMGPTLogDAO) Save(ctx context.Context, record GPTRecord) (int64, error) { + now := time.Now().UnixMilli() + record.Ctime = now + record.Utime = now + err := g.db.WithContext(ctx).Model(&GPTRecord{}). + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "id"}}, + DoUpdates: clause.AssignmentColumns([]string{"status", "utime"}), + }).Create(&record).Error + return record.Id, err +} + +func (g *GORMGPTLogDAO) FirstLog(ctx context.Context, id int64) (*GPTRecord, error) { + logModel := &GPTRecord{} + err := g.db.WithContext(ctx).Model(&GPTRecord{}).Where("id = ?", id).First(logModel).Error + return logModel, err +} + +type GPTRecord struct { + Id int64 `gorm:"primaryKey;autoIncrement;comment:积分流水表自增ID"` + Tid string `gorm:"type:varchar(256);not null;uniqueIndex:unq_tid;comment:一次请求的Tid只能有一次"` + Uid int64 `gorm:"not null;index:idx_user_id;comment:用户ID"` + Biz string `gorm:"type:varchar(256);not null;comment:业务类型名"` + Tokens int64 `gorm:"type:int;default:0;comment:扣费token数"` + Amount int64 `gorm:"type:int;default:0;comment:具体扣费的换算的钱,分为单位"` + Status uint8 `gorm:"type:tinyint unsigned;not null;default:1;comment:调用状态 1=成功, 2=失败"` + Input sqlx.JsonColumn[[]string] `gorm:"type:text;comment:调用请求的参数"` + KnowledgeId string `gorm:"type:varchar(256);not null;comment:使用的知识库 ID"` + PromptTemplate sql.NullString `gorm:"type:text;comment:PromptTemplate 模板,加上请求参数构成一个完整的 prompt"` + Answer sql.NullString `gorm:"type:text;comment:gpt的回答"` + Ctime int64 + Utime int64 +} + +func (l GPTRecord) TableName() string { + return "gpt_records" +} diff --git a/internal/ai/internal/repository/log.go b/internal/ai/internal/repository/log.go new file mode 100644 index 00000000..3da05eec --- /dev/null +++ b/internal/ai/internal/repository/log.go @@ -0,0 +1,62 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package repository + +import ( + "context" + + "github.com/ecodeclub/ekit/sqlx" + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" +) + +type GPTLogRepo interface { + SaveLog(ctx context.Context, gptLog domain.GPTRecord) (int64, error) +} + +// 调用日志 +type gptLogDAO struct { + logDao dao.GPTRecordDAO +} + +func NewGPTLogRepo(logDao dao.GPTRecordDAO) GPTLogRepo { + return &gptLogDAO{ + logDao: logDao, + } +} + +func (g *gptLogDAO) SaveLog(ctx context.Context, gptLog domain.GPTRecord) (int64, error) { + logEntity := g.toEntity(gptLog) + return g.logDao.Save(ctx, logEntity) +} + +func (g *gptLogDAO) toEntity(r domain.GPTRecord) dao.GPTRecord { + return dao.GPTRecord{ + Id: r.Id, + Tid: r.Tid, + Uid: r.Uid, + Biz: r.Biz, + Tokens: r.Tokens, + Amount: r.Amount, + KnowledgeId: r.KnowledgeId, + Input: sqlx.JsonColumn[[]string]{ + Valid: true, + Val: r.Input, + }, + Status: r.Status.ToUint8(), + PromptTemplate: sqlx.NewNullString(r.PromptTemplate), + Answer: sqlx.NewNullString(r.Answer), + } +} diff --git a/internal/ai/internal/repository/repository.go b/internal/ai/internal/repository/repository.go deleted file mode 100644 index 6587a8ab..00000000 --- a/internal/ai/internal/repository/repository.go +++ /dev/null @@ -1,75 +0,0 @@ -package repository - -import ( - "context" - "database/sql" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" -) - -type GPTLogRepo interface { - SaveCreditLog(ctx context.Context, GPTDeductLog domain.GPTCreditLog) (int64, error) - SaveLog(ctx context.Context, gptLog domain.GPTLog) (int64, error) -} - -type gptLogRepo struct { - logDao dao.GPTLogDAO -} - -func NewGPTLogRepo(logDao dao.GPTLogDAO) GPTLogRepo { - return &gptLogRepo{ - logDao: logDao, - } -} - -func (g *gptLogRepo) creditLogToEntity(gptLog domain.GPTCreditLog) dao.GptCreditLog { - return dao.GptCreditLog{ - Id: gptLog.Id, - Tid: gptLog.Tid, - Uid: gptLog.Uid, - Biz: gptLog.Biz, - Tokens: gptLog.Tokens, - Amount: gptLog.Amount, - Credit: gptLog.Credit, - Status: gptLog.Status.ToUint8(), - Prompt: sql.NullString{ - Valid: true, - String: gptLog.Prompt, - }, - Answer: sql.NullString{ - Valid: true, - String: gptLog.Answer, - }, - } -} - -func (g *gptLogRepo) logToEntity(gptLog domain.GPTLog) dao.GptLog { - return dao.GptLog{ - Id: gptLog.Id, - Tid: gptLog.Tid, - Uid: gptLog.Uid, - Biz: gptLog.Biz, - Tokens: gptLog.Tokens, - Amount: gptLog.Amount, - Status: gptLog.Status.ToUint8(), - Prompt: sql.NullString{ - Valid: true, - String: gptLog.Prompt, - }, - Answer: sql.NullString{ - Valid: true, - String: gptLog.Answer, - }, - } -} - -func (g *gptLogRepo) SaveCreditLog(ctx context.Context, gptDeductLog domain.GPTCreditLog) (int64, error) { - logEntity := g.creditLogToEntity(gptDeductLog) - return g.logDao.SaveCreditLog(ctx, logEntity) -} - -func (g *gptLogRepo) SaveLog(ctx context.Context, gptLog domain.GPTLog) (int64, error) { - logEntity := g.logToEntity(gptLog) - return g.logDao.SaveLog(ctx, logEntity) -} diff --git a/internal/ai/internal/service/gpt.go b/internal/ai/internal/service/gpt.go deleted file mode 100644 index 612956bd..00000000 --- a/internal/ai/internal/service/gpt.go +++ /dev/null @@ -1,31 +0,0 @@ -package service - -import ( - "context" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" -) - -//go:generate mockgen -source=./gpt.go -destination=../../mocks/gpt.mock.go -package=aimocks -typed=true GPTService -type GPTService interface { - Invoke(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) -} - -type gptService struct { - handlerFunc handler.HandleFunc -} - -func NewGPTService(handlers []handler.GptHandler) GPTService { - var hdl handler.HandleFunc - for i := len(handlers) - 1; i >= 0; i-- { - hdl = handlers[i].Next(hdl) - } - return &gptService{ - handlerFunc: hdl, - } -} - -func (g *gptService) Invoke(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { - return g.handlerFunc(ctx, req) -} diff --git a/internal/ai/internal/service/gpt/gpt.go b/internal/ai/internal/service/gpt/gpt.go new file mode 100644 index 00000000..73930378 --- /dev/null +++ b/internal/ai/internal/service/gpt/gpt.go @@ -0,0 +1,28 @@ +package gpt + +import ( + "context" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/biz" +) + +//go:generate mockgen -source=./gpt.go -destination=../../../mocks/gpt.mock.go -package=aimocks -typed=true Service +type Service interface { + Invoke(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) +} + +type gptService struct { + // 这边显示依赖 FacadeHandler + handler *biz.FacadeHandler +} + +func NewGPTService(facade *biz.FacadeHandler) Service { + return &gptService{ + handler: facade, + } +} + +func (g *gptService) Invoke(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + return g.handler.Handle(ctx, req) +} diff --git a/internal/ai/internal/service/gpt/handler/biz/composition_biz.go b/internal/ai/internal/service/gpt/handler/biz/composition_biz.go new file mode 100644 index 00000000..8537e63f --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/biz/composition_biz.go @@ -0,0 +1,55 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package biz + +import ( + "context" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" +) + +// CompositionHandler 通过组合 Handler 来完成某个业务 +// 后续该部分应该是动态计算的,通过结合配置来实现动态计算 +type CompositionHandler struct { + root handler.Handler + name string +} + +func (c *CompositionHandler) Handle(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + return c.root.Handle(ctx, req) +} + +func (c *CompositionHandler) Name() string { + return c.name +} + +func (c *CompositionHandler) Biz() string { + return c.name +} + +func NewCombinedBizHandler(name string, + common []handler.Builder, + gpt handler.Handler) *CompositionHandler { + root := gpt + for i := len(common) - 1; i >= 0; i-- { + current := common[i] + root = current.Next(root) + } + return &CompositionHandler{ + root: root, + name: name, + } +} diff --git a/internal/ai/internal/service/gpt/handler/biz/facade.go b/internal/ai/internal/service/gpt/handler/biz/facade.go new file mode 100644 index 00000000..247dc78b --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/biz/facade.go @@ -0,0 +1,33 @@ +package biz + +import ( + "context" + "errors" + "fmt" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + handler2 "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" +) + +var ErrUnknownBiz = errors.New("未知的业务") + +// FacadeHandler 用于分发业务Biz +type FacadeHandler struct { + bizMap map[string]handler2.Handler +} + +func (f *FacadeHandler) Handle(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + h, ok := f.bizMap[req.Biz] + if !ok { + return domain.GPTResponse{}, fmt.Errorf("%w biz: %s", ErrUnknownBiz, req.Biz) + } + return h.Handle(ctx, req) +} + +var _ handler2.Handler = &FacadeHandler{} + +func NewHandler(bizMap map[string]handler2.Handler) *FacadeHandler { + return &FacadeHandler{ + bizMap: bizMap, + } +} diff --git a/internal/ai/internal/service/gpt/handler/biz/question_examine.go b/internal/ai/internal/service/gpt/handler/biz/question_examine.go new file mode 100644 index 00000000..8844a451 --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/biz/question_examine.go @@ -0,0 +1,39 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package biz + +import ( + "context" + "fmt" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" +) + +type QuestionExamineBizHandlerBuilder struct { +} + +func NewQuestionExamineBizHandlerBuilder() *QuestionExamineBizHandlerBuilder { + return &QuestionExamineBizHandlerBuilder{} +} + +func (h *QuestionExamineBizHandlerBuilder) Next(next handler.Handler) handler.Handler { + return handler.HandleFunc(func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + // 把 input 和 prompt 结合起来 + prompt := fmt.Sprintf(req.Config.PromptTemplate, req.Input[0], req.Input[1]) + req.Prompt = prompt + return next.Handle(ctx, req) + }) +} diff --git a/internal/question/internal/service/mocks_ai.go b/internal/ai/internal/service/gpt/handler/biz/type.go similarity index 67% rename from internal/question/internal/service/mocks_ai.go rename to internal/ai/internal/service/gpt/handler/biz/type.go index 966eb0e0..4c083cd2 100644 --- a/internal/question/internal/service/mocks_ai.go +++ b/internal/ai/internal/service/gpt/handler/biz/type.go @@ -12,21 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package service +package biz import ( - "context" - - "github.com/ecodeclub/webook/internal/ai" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" ) -type AiService struct { -} - -func (a *AiService) Invoke(ctx context.Context, req ai.GPTRequest) (ai.GPTResponse, error) { - return ai.GPTResponse{ - Tokens: int(req.Uid), - Amount: req.Uid, - Answer: "评分:15K", - }, nil +// GPTBizHandler 近似于标记接口,也就是用于区分专属于业务的,和通用的 Handler +type GPTBizHandler interface { + handler.Handler + // Biz 它处理的业务 + Biz() string } diff --git a/internal/ai/internal/service/gpt/handler/config/builder.go b/internal/ai/internal/service/gpt/handler/config/builder.go new file mode 100644 index 00000000..eaf2d0ee --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/config/builder.go @@ -0,0 +1,48 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "context" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/repository" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" +) + +// HandlerBuilder 改为从数据库中读取 +type HandlerBuilder struct { + repo repository.ConfigRepository +} + +func NewBuilder(repo repository.ConfigRepository) *HandlerBuilder { + return &HandlerBuilder{ + repo: repo, + } +} + +func (b *HandlerBuilder) Next(next handler.Handler) handler.Handler { + return handler.HandleFunc(func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + // 读取配置 + cfg, err := b.repo.GetConfig(ctx, req.Biz) + if err != nil { + return domain.GPTResponse{}, err + } + req.Config = cfg + return next.Handle(ctx, req) + }) +} + +var _ handler.Builder = &HandlerBuilder{} diff --git a/internal/ai/internal/service/gpt/handler/credit/builder.go b/internal/ai/internal/service/gpt/handler/credit/builder.go new file mode 100644 index 00000000..f45f1d14 --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/credit/builder.go @@ -0,0 +1,103 @@ +package credit + +import ( + "context" + "errors" + "fmt" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/repository" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" + "github.com/ecodeclub/webook/internal/credit" + uuid "github.com/lithammer/shortuuid/v4" +) + +type HandlerBuilder struct { + creditSvc credit.Service + logRepo repository.GPTCreditLogRepo +} + +func (h *HandlerBuilder) Name() string { + return "credit" +} + +var ( + ErrInsufficientCredit = errors.New("积分不足") +) + +func NewHandlerBuilder(creSvc credit.Service, repo repository.GPTCreditLogRepo) *HandlerBuilder { + return &HandlerBuilder{ + creditSvc: creSvc, + logRepo: repo, + } +} + +func (h *HandlerBuilder) Next(next handler.Handler) handler.Handler { + return handler.HandleFunc(func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + cre, err := h.creditSvc.GetCreditsByUID(ctx, req.Uid) + if err != nil { + return domain.GPTResponse{}, err + } + // 如果剩余的积分不足就返回积分不足 + ok := h.checkCredit(cre) + if !ok { + return domain.GPTResponse{}, fmt.Errorf("%w, 余额非正数,无法继续调用,用户 %d", + ErrInsufficientCredit, req.Uid) + } + + // 调用下层服务 + resp, err := next.Handle(ctx, req) + if err != nil { + return resp, err + } + + // 扣款 + id, err := h.logRepo.SaveCredit(ctx, h.newLog(req, resp)) + if err != nil { + return domain.GPTResponse{}, err + } + err = h.creditSvc.AddCredits(context.Background(), credit.Credit{ + Uid: req.Uid, + Logs: []credit.CreditLog{ + { + Key: uuid.New(), + Uid: req.Uid, + Biz: "ai-gpt", + BizId: id, + Desc: "ai-gpt服务", + }, + }, + }) + if err != nil { + _, _ = h.logRepo.SaveCredit(ctx, domain.GPTCredit{ + Id: id, + Status: domain.CreditStatusFailed, + }) + return domain.GPTResponse{}, err + } else { + _, err = h.logRepo.SaveCredit(ctx, domain.GPTCredit{ + Id: id, + Status: domain.CreditStatusSuccess, + }) + } + return resp, err + }) +} + +func (h *HandlerBuilder) newLog(req domain.GPTRequest, resp domain.GPTResponse) domain.GPTCredit { + return domain.GPTCredit{ + Tid: req.Tid, + Uid: req.Uid, + Biz: req.Biz, + Tokens: resp.Tokens, + Amount: resp.Amount, + Status: domain.CreditStatusProcessing, + } +} + +func (h *HandlerBuilder) checkCredit(cre credit.Credit) bool { + // 判断积分是否满足 + // 并不能用一次调用的最大 token 数量来算,因为要考虑用户可能最后只剩下一点点钱了, + // 这点钱不够最大数量,但是够一次普通调用 + return cre.TotalAmount > 0 +} diff --git a/internal/ai/internal/service/gpt/handler/gpt/zhipu/handler.go b/internal/ai/internal/service/gpt/handler/gpt/zhipu/handler.go new file mode 100644 index 00000000..1fc0f885 --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/gpt/zhipu/handler.go @@ -0,0 +1,71 @@ +package zhipu + +import ( + "context" + "math" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/yankeguo/zhipu" +) + +// Handler 如果后续有不同的实现,就提供不同的实现 +type Handler struct { + client *zhipu.Client + svc *zhipu.ChatCompletionService + // 价格和 model 进行绑定的 + price float64 +} + +func NewHandler(apikey string, + price float64) (*Handler, error) { + client, err := zhipu.NewClient(zhipu.WithAPIKey(apikey)) + if err != nil { + return nil, err + } + const model = "glm-4" + svc := client.ChatCompletion(model) + return &Handler{ + client: client, + // 后续可以做成可配置的 + svc: svc, + price: price, + }, err +} + +func (h *Handler) Name() string { + return "gpt" +} + +func (h *Handler) Handle(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + // 这边它不会调用 next,因为它是最终的出口 + msg := h.newParams(req.Input) + completion, err := h.svc.AddTool(zhipu.ChatCompletionToolRetrieval{ + KnowledgeID: req.Config.KnowledgeId, + PromptTemplate: req.Config.PromptTemplate, + }).AddMessage(msg).Do(ctx) + if err != nil { + return domain.GPTResponse{}, err + } + tokens := completion.Usage.TotalTokens + // 现在的报价都是 N/1k token + // 而后向上取整 + amt := math.Ceil(float64(tokens) * h.price / 1000) + // 金额只有具体的模型才知道怎么算 + resp := domain.GPTResponse{ + Tokens: tokens, + Amount: int64(amt), + } + + if len(completion.Choices) > 0 { + resp.Answer = completion.Choices[0].Message.Content + } + return resp, nil +} + +func (h *Handler) newParams(inputs []string) zhipu.ChatCompletionMessage { + msg := inputs[0] + return zhipu.ChatCompletionMessage{ + Role: "user", + Content: msg, + } +} diff --git a/internal/ai/internal/service/gpt/handler/log/builder.go b/internal/ai/internal/service/gpt/handler/log/builder.go new file mode 100644 index 00000000..86fa609b --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/log/builder.go @@ -0,0 +1,45 @@ +package log + +import ( + "context" + + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/gotomicro/ego/core/elog" +) + +type HandlerBuilder struct { + logger *elog.Component +} + +var _ handler.Builder = &HandlerBuilder{} + +func NewHandler() *HandlerBuilder { + return &HandlerBuilder{ + logger: elog.DefaultLogger, + } +} + +func (h *HandlerBuilder) Name() string { + return "log" +} + +func (h *HandlerBuilder) Next(next handler.Handler) handler.Handler { + return handler.HandleFunc(func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + logger := h.logger.With(elog.String("tid", req.Tid), + elog.Int64("uid", req.Uid), + elog.String("biz", req.Biz)) + // 记录请求 + logger.Info("请求 GPT") + resp, err := next.Handle(ctx, req) + if err != nil { + // 记录错误 + logger.Error("请求gpt服务失败", elog.FieldErr(err)) + return resp, err + } + // 记录响应 + logger.Info("请求gpt服务响应成功", elog.Int64("tokens", resp.Tokens)) + return resp, err + }) +} diff --git a/internal/ai/internal/service/gpt/handler/mocks/handler.mock.go b/internal/ai/internal/service/gpt/handler/mocks/handler.mock.go new file mode 100644 index 00000000..82a85913 --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/mocks/handler.mock.go @@ -0,0 +1,141 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./type.go +// +// Generated by this command: +// +// mockgen -source=./type.go -destination=./mocks/handler.mock.go -package=hdlmocks -typed=true Handler +// +// Package hdlmocks is a generated GoMock package. +package hdlmocks + +import ( + context "context" + reflect "reflect" + + domain "github.com/ecodeclub/webook/internal/ai/internal/domain" + handler "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" + gomock "go.uber.org/mock/gomock" +) + +// MockHandler is a mock of Handler interface. +type MockHandler struct { + ctrl *gomock.Controller + recorder *MockHandlerMockRecorder +} + +// MockHandlerMockRecorder is the mock recorder for MockHandler. +type MockHandlerMockRecorder struct { + mock *MockHandler +} + +// NewMockHandler creates a new mock instance. +func NewMockHandler(ctrl *gomock.Controller) *MockHandler { + mock := &MockHandler{ctrl: ctrl} + mock.recorder = &MockHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHandler) EXPECT() *MockHandlerMockRecorder { + return m.recorder +} + +// Handle mocks base method. +func (m *MockHandler) Handle(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Handle", ctx, req) + ret0, _ := ret[0].(domain.GPTResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Handle indicates an expected call of Handle. +func (mr *MockHandlerMockRecorder) Handle(ctx, req any) *HandlerHandleCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handle", reflect.TypeOf((*MockHandler)(nil).Handle), ctx, req) + return &HandlerHandleCall{Call: call} +} + +// HandlerHandleCall wrap *gomock.Call +type HandlerHandleCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *HandlerHandleCall) Return(arg0 domain.GPTResponse, arg1 error) *HandlerHandleCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *HandlerHandleCall) Do(f func(context.Context, domain.GPTRequest) (domain.GPTResponse, error)) *HandlerHandleCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *HandlerHandleCall) DoAndReturn(f func(context.Context, domain.GPTRequest) (domain.GPTResponse, error)) *HandlerHandleCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockBuilder is a mock of Builder interface. +type MockBuilder struct { + ctrl *gomock.Controller + recorder *MockBuilderMockRecorder +} + +// MockBuilderMockRecorder is the mock recorder for MockBuilder. +type MockBuilderMockRecorder struct { + mock *MockBuilder +} + +// NewMockBuilder creates a new mock instance. +func NewMockBuilder(ctrl *gomock.Controller) *MockBuilder { + mock := &MockBuilder{ctrl: ctrl} + mock.recorder = &MockBuilderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBuilder) EXPECT() *MockBuilderMockRecorder { + return m.recorder +} + +// Next mocks base method. +func (m *MockBuilder) Next(next handler.Handler) handler.Handler { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next", next) + ret0, _ := ret[0].(handler.Handler) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockBuilderMockRecorder) Next(next any) *BuilderNextCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockBuilder)(nil).Next), next) + return &BuilderNextCall{Call: call} +} + +// BuilderNextCall wrap *gomock.Call +type BuilderNextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *BuilderNextCall) Return(arg0 handler.Handler) *BuilderNextCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *BuilderNextCall) Do(f func(handler.Handler) handler.Handler) *BuilderNextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *BuilderNextCall) DoAndReturn(f func(handler.Handler) handler.Handler) *BuilderNextCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/internal/ai/internal/service/gpt/handler/record/builder.go b/internal/ai/internal/service/gpt/handler/record/builder.go new file mode 100644 index 00000000..31950b20 --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/record/builder.go @@ -0,0 +1,55 @@ +package record + +import ( + "context" + + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler" + "github.com/gotomicro/ego/core/elog" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" + "github.com/ecodeclub/webook/internal/ai/internal/repository" +) + +type HandlerBuilder struct { + repo repository.GPTLogRepo + logger *elog.Component +} + +func NewHandler(repo repository.GPTLogRepo) *HandlerBuilder { + return &HandlerBuilder{ + repo: repo, + logger: elog.DefaultLogger, + } +} +func (h *HandlerBuilder) Name() string { + return "response" +} + +func (h *HandlerBuilder) Next(next handler.Handler) handler.Handler { + return handler.HandleFunc(func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + log := domain.GPTRecord{ + Tid: req.Tid, + Biz: req.Biz, + Uid: req.Uid, + Input: req.Input, + KnowledgeId: req.Config.KnowledgeId, + PromptTemplate: req.Config.PromptTemplate, + } + defer func() { + _, err1 := h.repo.SaveLog(ctx, log) + if err1 != nil { + h.logger.Error("保存 GPT 访问记录失败", elog.FieldErr(err1)) + } + }() + resp, err := next.Handle(ctx, req) + if err != nil { + log.Status = domain.RecordStatusFailed + return domain.GPTResponse{}, err + } + log.Tokens = resp.Tokens + log.Amount = resp.Amount + log.Status = domain.RecordStatusProcessing + log.Answer = resp.Answer + return resp, err + }) +} diff --git a/internal/ai/internal/service/gpt/handler/type.go b/internal/ai/internal/service/gpt/handler/type.go new file mode 100644 index 00000000..b2419861 --- /dev/null +++ b/internal/ai/internal/service/gpt/handler/type.go @@ -0,0 +1,22 @@ +package handler + +import ( + "context" + + "github.com/ecodeclub/webook/internal/ai/internal/domain" +) + +type HandleFunc func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) + +func (f HandleFunc) Handle(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { + return f(ctx, req) +} + +//go:generate mockgen -source=./type.go -destination=./mocks/handler.mock.go -package=hdlmocks -typed=true Handler +type Handler interface { + Handle(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) +} + +type Builder interface { + Next(next Handler) Handler +} diff --git a/internal/ai/internal/service/handler/biz/handler.go b/internal/ai/internal/service/handler/biz/handler.go deleted file mode 100644 index 992ed94a..00000000 --- a/internal/ai/internal/service/handler/biz/handler.go +++ /dev/null @@ -1,34 +0,0 @@ -package biz - -import ( - "context" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" -) - -// 用于分发业务Biz -type FacadeHandler struct { - bizMap map[string]handler.GptHandler -} - -func (h *FacadeHandler) Name() string { - return "biz_facade" -} - -func NewHandler(bizMap map[string]handler.GptHandler) *FacadeHandler { - return &FacadeHandler{ - bizMap: bizMap, - } -} - -func (h *FacadeHandler) Next(next handler.HandleFunc) handler.HandleFunc { - return func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { - handleFunc, ok := h.bizMap[req.Biz] - if !ok { - return domain.GPTResponse{}, handler.ErrUnknownBiz - } - nextFunc := handleFunc.Next(next) - return nextFunc(ctx, req) - } -} diff --git a/internal/ai/internal/service/handler/config/handler.go b/internal/ai/internal/service/handler/config/handler.go deleted file mode 100644 index ceffa219..00000000 --- a/internal/ai/internal/service/handler/config/handler.go +++ /dev/null @@ -1,49 +0,0 @@ -package config - -import ( - "context" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" -) - -type Handler struct { - configMap map[string]domain.GPTBiz -} - -func (h *Handler) Name() string { - return "config" -} - -func InitHandler() *Handler { - cfgs := []domain.GPTBiz{ - { - Biz: "simple", - AmountPerToken: 1, - CreditPerToken: 1, - MaxTokensPerTime: 1000, - }, - } - cfgMap := make(map[string]domain.GPTBiz, len(cfgs)) - for _, bizConfig := range cfgs { - cfgMap[bizConfig.Biz] = bizConfig - } - return NewHandler(cfgMap) -} - -func NewHandler(configMap map[string]domain.GPTBiz) *Handler { - return &Handler{ - configMap: configMap, - } -} - -func (h *Handler) Next(next handler.HandleFunc) handler.HandleFunc { - return func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { - cfg, ok := h.configMap[req.Biz] - if !ok { - return domain.GPTResponse{}, handler.ErrUnknownBiz - } - req.BizConfig = cfg - return next(ctx, req) - } -} diff --git a/internal/ai/internal/service/handler/credit/handler.go b/internal/ai/internal/service/handler/credit/handler.go deleted file mode 100644 index bdac7b09..00000000 --- a/internal/ai/internal/service/handler/credit/handler.go +++ /dev/null @@ -1,119 +0,0 @@ -package credit - -import ( - "context" - "encoding/json" - "errors" - "math" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/repository" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" - "github.com/ecodeclub/webook/internal/credit" - uuid "github.com/lithammer/shortuuid/v4" -) - -type Handler struct { - creditSvc credit.Service - logRepo repository.GPTLogRepo -} - -func (h *Handler) Name() string { - return "credit" -} - -var ( - ErrInsufficientCredit = errors.New("积分不足") -) - -func NewHandler(creSvc credit.Service, repo repository.GPTLogRepo) *Handler { - return &Handler{ - creditSvc: creSvc, - logRepo: repo, - } -} - -func (h *Handler) Next(next handler.HandleFunc) handler.HandleFunc { - return func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { - bizConfig := req.BizConfig - cre, err := h.creditSvc.GetCreditsByUID(ctx, req.Uid) - if err != nil { - return domain.GPTResponse{}, err - } - // 如果剩余的积分不足就返回积分不足 - err = h.checkCredit(cre, bizConfig) - if err != nil { - return domain.GPTResponse{}, err - } - - // 调用下层服务 - resp, err := next(ctx, req) - if err != nil { - return resp, err - } - - // 扣款 - needCredit := h.roundUp(float64(resp.Tokens) * bizConfig.CreditPerToken) - needAmount := h.roundUp(float64(resp.Tokens) * bizConfig.AmountPerToken) - resp.Amount = int64(needAmount) - id, err := h.logRepo.SaveCreditLog(ctx, h.convertToDomain(needCredit, needAmount, req, resp)) - if err != nil { - return domain.GPTResponse{}, err - } - err = h.creditSvc.AddCredits(context.Background(), credit.Credit{ - Uid: req.Uid, - Logs: []credit.CreditLog{ - { - Key: uuid.New(), - Uid: req.Uid, - ChangeAmount: int64(-1 * needCredit), - Biz: "ai-gpt", - BizId: id, - Desc: "ai-gpt服务", - }, - }, - }) - if err != nil { - _, _ = h.logRepo.SaveCreditLog(ctx, domain.GPTCreditLog{ - Id: id, - Status: domain.FailLogStatus, - }) - return domain.GPTResponse{}, err - } else { - _, err = h.logRepo.SaveCreditLog(ctx, domain.GPTCreditLog{ - Id: id, - Status: domain.SuccessStatus, - }) - } - return resp, err - } -} - -func (h *Handler) convertToDomain(needCredit, needAmount int, req domain.GPTRequest, resp domain.GPTResponse) domain.GPTCreditLog { - prompt, _ := json.Marshal(req.Input) - return domain.GPTCreditLog{ - Tid: req.Tid, - Uid: req.Uid, - Biz: req.Biz, - Tokens: int64(resp.Tokens), - Amount: int64(needAmount), - Credit: int64(needCredit), - Status: domain.ProcessingStatus, - Prompt: string(prompt), - Answer: resp.Answer, - } -} - -func (h *Handler) checkCredit(cre credit.Credit, bizConfig domain.GPTBiz) error { - // 判断积分是否满足 - wantCre := h.roundUp(float64(bizConfig.MaxTokensPerTime) * bizConfig.CreditPerToken) - if wantCre > int(cre.TotalAmount) { - return ErrInsufficientCredit - } - return nil -} - -// 向上取整 -func (h *Handler) roundUp(val float64) int { - return int(math.Ceil(val)) -} diff --git a/internal/ai/internal/service/handler/error.go b/internal/ai/internal/service/handler/error.go deleted file mode 100644 index 2df54639..00000000 --- a/internal/ai/internal/service/handler/error.go +++ /dev/null @@ -1,5 +0,0 @@ -package handler - -import "errors" - -var ErrUnknownBiz = errors.New("未知的业务") diff --git a/internal/ai/internal/service/handler/gpt/getter/polling.go b/internal/ai/internal/service/handler/gpt/getter/polling.go deleted file mode 100644 index 828652a1..00000000 --- a/internal/ai/internal/service/handler/gpt/getter/polling.go +++ /dev/null @@ -1,25 +0,0 @@ -package getter - -import ( - "sync/atomic" - - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/sdk" -) - -// 轮询 -type PollingGetter struct { - Count int64 - Sdks []sdk.GPTSdk -} - -func NewPollingGetter(sdks []sdk.GPTSdk) *PollingGetter { - return &PollingGetter{ - Sdks: sdks, - } -} - -func (p *PollingGetter) GetSdk(biz string) (sdk.GPTSdk, error) { - res := p.Sdks[int(p.Count)%len(p.Sdks)] - atomic.AddInt64(&p.Count, 1) - return res, nil -} diff --git a/internal/ai/internal/service/handler/gpt/getter/polling_test.go b/internal/ai/internal/service/handler/gpt/getter/polling_test.go deleted file mode 100644 index 69775794..00000000 --- a/internal/ai/internal/service/handler/gpt/getter/polling_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package getter - -import ( - "context" - "testing" - - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/sdk" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func Test_Polling(t *testing.T) { - testcases := []struct { - name string - sdks []sdk.GPTSdk - wantMockSdks []sdk.GPTSdk - wantErr error - }{ - { - name: "sdk轮询拿出", - sdks: []sdk.GPTSdk{ - &MockSdk{ - index: 0, - }, - &MockSdk{ - index: 1, - }, - &MockSdk{ - index: 2, - }, - }, - wantMockSdks: []sdk.GPTSdk{ - &MockSdk{ - index: 0, - }, - &MockSdk{ - index: 1, - }, - &MockSdk{ - index: 2, - }, - }, - }, - } - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - getter := NewPollingGetter(tc.wantMockSdks) - for i := 0; i < len(tc.wantMockSdks); i++ { - gsdk, err := getter.GetSdk("xxx") - require.NoError(t, err) - index, _, err := gsdk.Invoke(context.Background(), []string{}) - require.NoError(t, err) - assert.Equal(t, index, int64(i)) - } - }) - } -} - -type MockSdk struct { - index int64 -} - -func (m *MockSdk) Invoke(ctx context.Context, input []string) (int64, string, error) { - return m.index, "", nil -} diff --git a/internal/ai/internal/service/handler/gpt/getter/type.go b/internal/ai/internal/service/handler/gpt/getter/type.go deleted file mode 100644 index dad0e947..00000000 --- a/internal/ai/internal/service/handler/gpt/getter/type.go +++ /dev/null @@ -1,7 +0,0 @@ -package getter - -import "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/sdk" - -type AiSdkGetter interface { - GetSdk(biz string) (sdk.GPTSdk, error) -} diff --git a/internal/ai/internal/service/handler/gpt/handler.go b/internal/ai/internal/service/handler/gpt/handler.go deleted file mode 100644 index 765195de..00000000 --- a/internal/ai/internal/service/handler/gpt/handler.go +++ /dev/null @@ -1,68 +0,0 @@ -package gpt - -import ( - "context" - "time" - - "github.com/ecodeclub/ekit/retry" - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/getter" -) - -type Handler struct { - sdkGetter getter.AiSdkGetter - retryFunc *retry.ExponentialBackoffRetryStrategy -} - -const ( - defaultMinRetryInterval = 100 * time.Millisecond - defaultMaxRetryInterval = 10 * time.Second - defaultMaxRetryTimes = 10 -) - -func NewHandler(sdkGetter getter.AiSdkGetter) (*Handler, error) { - strategy, err := retry.NewExponentialBackoffRetryStrategy(defaultMinRetryInterval, defaultMaxRetryInterval, defaultMaxRetryTimes) - if err != nil { - return nil, err - } - return &Handler{ - sdkGetter: sdkGetter, - retryFunc: strategy, - }, nil -} - -func (h *Handler) Name() string { - return "gpt" -} - -func (h *Handler) Next(next handler.HandleFunc) handler.HandleFunc { - return func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { - for { - gptSdk, err := h.sdkGetter.GetSdk(req.Biz) - if err != nil { - sleepTime, ok := h.retryFunc.Next() - if ok { - time.Sleep(sleepTime) - continue - } else { - return domain.GPTResponse{}, err - } - } - tokens, ans, err := gptSdk.Invoke(ctx, req.Input) - if err != nil { - sleepTime, ok := h.retryFunc.Next() - if ok { - time.Sleep(sleepTime) - continue - } else { - return domain.GPTResponse{}, err - } - } - return domain.GPTResponse{ - Tokens: int(tokens), - Answer: ans, - }, nil - } - } -} diff --git a/internal/ai/internal/service/handler/gpt/mocks/gpt.mock.go b/internal/ai/internal/service/handler/gpt/mocks/gpt.mock.go deleted file mode 100644 index a7b28705..00000000 --- a/internal/ai/internal/service/handler/gpt/mocks/gpt.mock.go +++ /dev/null @@ -1,80 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./type.go -// -// Generated by this command: -// -// mockgen -source=./type.go -destination=../mocks/gpt.mock.go -package=aimocks -typed=true GPTSdk -// - -// Package aimocks is a generated GoMock package. -package aimocks - -import ( - context "context" - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockGPTSdk is a mock of GPTSdk interface. -type MockGPTSdk struct { - ctrl *gomock.Controller - recorder *MockGPTSdkMockRecorder -} - -// MockGPTSdkMockRecorder is the mock recorder for MockGPTSdk. -type MockGPTSdkMockRecorder struct { - mock *MockGPTSdk -} - -// NewMockGPTSdk creates a new mock instance. -func NewMockGPTSdk(ctrl *gomock.Controller) *MockGPTSdk { - mock := &MockGPTSdk{ctrl: ctrl} - mock.recorder = &MockGPTSdkMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockGPTSdk) EXPECT() *MockGPTSdkMockRecorder { - return m.recorder -} - -// Invoke mocks base method. -func (m *MockGPTSdk) Invoke(ctx context.Context, input []string) (int64, string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Invoke", ctx, input) - ret0, _ := ret[0].(int64) - ret1, _ := ret[1].(string) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// Invoke indicates an expected call of Invoke. -func (mr *MockGPTSdkMockRecorder) Invoke(ctx, input any) *MockGPTSdkInvokeCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Invoke", reflect.TypeOf((*MockGPTSdk)(nil).Invoke), ctx, input) - return &MockGPTSdkInvokeCall{Call: call} -} - -// MockGPTSdkInvokeCall wrap *gomock.Call -type MockGPTSdkInvokeCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockGPTSdkInvokeCall) Return(arg0 int64, arg1 string, arg2 error) *MockGPTSdkInvokeCall { - c.Call = c.Call.Return(arg0, arg1, arg2) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockGPTSdkInvokeCall) Do(f func(context.Context, []string) (int64, string, error)) *MockGPTSdkInvokeCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockGPTSdkInvokeCall) DoAndReturn(f func(context.Context, []string) (int64, string, error)) *MockGPTSdkInvokeCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/internal/ai/internal/service/handler/gpt/sdk/type.go b/internal/ai/internal/service/handler/gpt/sdk/type.go deleted file mode 100644 index 50db3e28..00000000 --- a/internal/ai/internal/service/handler/gpt/sdk/type.go +++ /dev/null @@ -1,11 +0,0 @@ -package sdk - -import "context" - -//go:generate mockgen -source=./type.go -destination=../mocks/gpt.mock.go -package=aimocks -typed=true GPTSdk - -// 各个ai sdk统一的抽象 -type GPTSdk interface { - // 返回值 第一个是token数,第二个为返回内容 - Invoke(ctx context.Context, input []string) (int64, string, error) -} diff --git a/internal/ai/internal/service/handler/gpt/sdk/zhipu/client.go b/internal/ai/internal/service/handler/gpt/sdk/zhipu/client.go deleted file mode 100644 index cc9709ce..00000000 --- a/internal/ai/internal/service/handler/gpt/sdk/zhipu/client.go +++ /dev/null @@ -1,31 +0,0 @@ -package zhipu - -import ( - "context" - - "github.com/yankeguo/zhipu" -) - -type Client struct { - client *zhipu.Client - apiKey string - knowledgeId string -} - -func NewClient(apikey, knowledgeId string) (*Client, error) { - client, err := zhipu.NewClient(zhipu.WithAPIKey(apikey)) - if err != nil { - return nil, err - } - return &Client{ - client: client, - apiKey: apikey, - knowledgeId: knowledgeId, - }, nil -} - -func (c *Client) ChatCompletion(ctx context.Context, msg zhipu.ChatCompletionMessage) (zhipu.ChatCompletionResponse, error) { - return c.client.ChatCompletion("glm-4").AddTool(zhipu.ChatCompletionToolRetrieval{ - KnowledgeID: c.knowledgeId, - }).AddMessage(msg).Do(ctx) -} diff --git a/internal/ai/internal/service/handler/gpt/sdk/zhipu/client_test.go b/internal/ai/internal/service/handler/gpt/sdk/zhipu/client_test.go deleted file mode 100644 index 6b306e98..00000000 --- a/internal/ai/internal/service/handler/gpt/sdk/zhipu/client_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package zhipu - -//func Test_Example(t *testing.T) { -// c, err := NewClient("apikey", "knowageId") -// require.NoError(t, err) -// postResponse, err := c.ChatCompletion(context.Background(), zhipu.ChatCompletionMessage{ -// Role: "user", -// Content: `现在你扮演一个技术面试官,你会专注于考察程序员的知识是否丰富。 -// -//接下来我会提供给你一个面试问题和一个候选人的回答。 -// -//你需要根据面试问题从知识库中找到该问题的答案,并且提取出来这个问题的答案的关键点,作为标准答案的关键点。 -// -//你将按照 15K,25K 和 35K 来分别列出关键点。 -// -//而后你要提取出来候选人回答的关键点。 -// -//接着你会对比标准答案和候选人回答,进行评分。评分分成三级:15K、25K、35K。 -// -//评分标准是: -//1. 如果候选人回答出来 15K 部分的所有关键点,那么评分至少是 15K; -//2. 在 1 的基础上,如果标准答案不存在 25K 部分的回答,那么候选人评分至少是 25K;如果候选人回答出来了 25K 部分的所有关键点,那么评分至少是 25K; -//3. 在 2 的基础上,如果标准答案不存在 25K 部分的回答,那么候选人评分至少是 35K;如果候选人回答出来了在 35K 部分的所有关键点,那么评分是 35K; -// -//你不需要输出任何你提取的关键点,而是只输出以下内容: -//1. 评分: 输出 15K、25K、35K 之一。 -//2. 遗漏的关键点,按照 15K,25K 和 35K 分别列出来 -// -// -//这是问题: -//Go 使用的三色标记法是如何运行的? -// -//这是回答: -//三色标记法的的原作原理还是比较简单的: -// -//1. 在初始状态,所有的对象都是白色; -//2. 逐一扫描这些对象,以及这些对象指向的对象。在扫描的时候,这个对象就是灰色; -//3. 把对象的所有的儿子都扫描完毕,这个对象就被标记为黑色 -//4. 从实现上来说,这个过程很类似于树的广度优先遍历,或者有向图的遍历。`, -// }) -// require.NoError(t, err) -// v, err := json.Marshal(postResponse) -// require.NoError(t, err) -// log.Println(string(v)) -//} diff --git a/internal/ai/internal/service/handler/gpt/sdk/zhipu/gpt.go b/internal/ai/internal/service/handler/gpt/sdk/zhipu/gpt.go deleted file mode 100644 index 775d51f6..00000000 --- a/internal/ai/internal/service/handler/gpt/sdk/zhipu/gpt.go +++ /dev/null @@ -1,42 +0,0 @@ -package zhipu - -import ( - "context" - - "github.com/yankeguo/zhipu" -) - -// 智谱 - -type GPT struct { - sdk *Client -} - -func NewGpt(sdk *Client) *GPT { - return &GPT{ - sdk: sdk, - } -} - -func (g *GPT) Invoke(ctx context.Context, input []string) (int64, string, error) { - params := g.newParams(input) - resp, err := g.sdk.ChatCompletion(ctx, params) - if err != nil { - return 0, "", err - } - tokens := resp.Usage.TotalTokens - // 默认是一个问题 - var ans string - if len(resp.Choices) > 0 { - ans = resp.Choices[0].Message.Content - } - return tokens, ans, nil -} - -func (g *GPT) newParams(inputs []string) zhipu.ChatCompletionMessage { - msg := inputs[0] - return zhipu.ChatCompletionMessage{ - Role: "user", - Content: msg, - } -} diff --git a/internal/ai/internal/service/handler/log/handler.go b/internal/ai/internal/service/handler/log/handler.go deleted file mode 100644 index f5ad5c31..00000000 --- a/internal/ai/internal/service/handler/log/handler.go +++ /dev/null @@ -1,40 +0,0 @@ -package log - -import ( - "context" - "fmt" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" - "github.com/gotomicro/ego/core/elog" -) - -type Handler struct { - logger *elog.Component -} - -func NewHandler() *Handler { - return &Handler{ - logger: elog.DefaultLogger, - } -} - -func (h *Handler) Name() string { - return "log" -} - -func (h *Handler) Next(next handler.HandleFunc) handler.HandleFunc { - return func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { - // 记录请求 - h.logger.Info(fmt.Sprintf("请求gpt服务请求id为 %s", req.Tid), elog.FieldExtMessage(req)) - resp, err := next(ctx, req) - if err != nil { - // 记录错误 - h.logger.Error(fmt.Sprintf("请求gpt服务失败请求id为 %s", req.Tid), elog.FieldErr(err)) - return resp, err - } - // 记录响应 - h.logger.Info(fmt.Sprintf("请求gpt服务请求id为 %s", req.Tid), elog.FieldExtMessage(resp)) - return resp, err - } -} diff --git a/internal/ai/internal/service/handler/response/handler.go b/internal/ai/internal/service/handler/response/handler.go deleted file mode 100644 index 45ebabda..00000000 --- a/internal/ai/internal/service/handler/response/handler.go +++ /dev/null @@ -1,52 +0,0 @@ -package response - -import ( - "context" - "encoding/json" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/repository" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" -) - -type Handler struct { - repo repository.GPTLogRepo -} - -func NewHandler(repo repository.GPTLogRepo) *Handler { - return &Handler{ - repo: repo, - } -} -func (h *Handler) Name() string { - return "response" -} - -func (h *Handler) Next(next handler.HandleFunc) handler.HandleFunc { - return func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { - resp, err := next(ctx, req) - msgByte, _ := json.Marshal(req.Input) - msg := string(msgByte) - if err != nil { - _, _ = h.repo.SaveLog(ctx, domain.GPTLog{ - Tid: req.Tid, - Biz: req.Biz, - Uid: req.Uid, - Prompt: msg, - Status: domain.FailLogStatus, - }) - return domain.GPTResponse{}, err - } - _, err = h.repo.SaveLog(ctx, domain.GPTLog{ - Tid: req.Tid, - Uid: req.Uid, - Biz: req.Biz, - Tokens: int64(resp.Tokens), - Amount: resp.Amount, - Status: domain.ProcessingStatus, - Prompt: msg, - Answer: resp.Answer, - }) - return resp, err - } -} diff --git a/internal/ai/internal/service/handler/simple/handler.go b/internal/ai/internal/service/handler/simple/handler.go deleted file mode 100644 index d87004ec..00000000 --- a/internal/ai/internal/service/handler/simple/handler.go +++ /dev/null @@ -1,23 +0,0 @@ -package simple - -import ( - "context" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" -) - -// 最简业务handler -type Handler struct { - handlerFunc handler.HandleFunc -} - -func (h *Handler) Name() string { - return "simple" -} - -func (h *Handler) Next(next handler.HandleFunc) handler.HandleFunc { - return func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { - return h.handlerFunc(ctx, req) - } -} diff --git a/internal/ai/internal/service/handler/simple/ioc.go b/internal/ai/internal/service/handler/simple/ioc.go deleted file mode 100644 index f2f720ce..00000000 --- a/internal/ai/internal/service/handler/simple/ioc.go +++ /dev/null @@ -1,27 +0,0 @@ -package simple - -import ( - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/credit" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/log" -) - -func InitHandler( - logHandler *log.Handler, - creditHandler *credit.Handler, - gptHandler *gpt.Handler, -) *Handler { - handlers := []handler.GptHandler{ - logHandler, - creditHandler, - gptHandler, - } - var h handler.HandleFunc - for i := len(handlers) - 1; i >= 0; i-- { - h = handlers[i].Next(h) - } - return &Handler{ - handlerFunc: h, - } -} diff --git a/internal/ai/internal/service/handler/type.go b/internal/ai/internal/service/handler/type.go deleted file mode 100644 index 57566e94..00000000 --- a/internal/ai/internal/service/handler/type.go +++ /dev/null @@ -1,14 +0,0 @@ -package handler - -import ( - "context" - - "github.com/ecodeclub/webook/internal/ai/internal/domain" -) - -type HandleFunc func(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) - -type GptHandler interface { - Name() string - Next(next HandleFunc) HandleFunc -} diff --git a/internal/ai/mocks/gpt.mock.go b/internal/ai/mocks/gpt.mock.go index 73f1f25c..83d12882 100644 --- a/internal/ai/mocks/gpt.mock.go +++ b/internal/ai/mocks/gpt.mock.go @@ -3,9 +3,8 @@ // // Generated by this command: // -// mockgen -source=./gpt.go -destination=../../mocks/gpt.mock.go -package=aimocks -typed=true GPTService +// mockgen -source=./gpt.go -destination=../../../mocks/gpt.mock.go -package=aimocks -typed=true Service // - // Package aimocks is a generated GoMock package. package aimocks @@ -17,31 +16,31 @@ import ( gomock "go.uber.org/mock/gomock" ) -// MockGPTService is a mock of GPTService interface. -type MockGPTService struct { +// MockService is a mock of Service interface. +type MockService struct { ctrl *gomock.Controller - recorder *MockGPTServiceMockRecorder + recorder *MockServiceMockRecorder } -// MockGPTServiceMockRecorder is the mock recorder for MockGPTService. -type MockGPTServiceMockRecorder struct { - mock *MockGPTService +// MockServiceMockRecorder is the mock recorder for MockService. +type MockServiceMockRecorder struct { + mock *MockService } -// NewMockGPTService creates a new mock instance. -func NewMockGPTService(ctrl *gomock.Controller) *MockGPTService { - mock := &MockGPTService{ctrl: ctrl} - mock.recorder = &MockGPTServiceMockRecorder{mock} +// NewMockService creates a new mock instance. +func NewMockService(ctrl *gomock.Controller) *MockService { + mock := &MockService{ctrl: ctrl} + mock.recorder = &MockServiceMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockGPTService) EXPECT() *MockGPTServiceMockRecorder { +func (m *MockService) EXPECT() *MockServiceMockRecorder { return m.recorder } // Invoke mocks base method. -func (m *MockGPTService) Invoke(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { +func (m *MockService) Invoke(ctx context.Context, req domain.GPTRequest) (domain.GPTResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Invoke", ctx, req) ret0, _ := ret[0].(domain.GPTResponse) @@ -50,31 +49,31 @@ func (m *MockGPTService) Invoke(ctx context.Context, req domain.GPTRequest) (dom } // Invoke indicates an expected call of Invoke. -func (mr *MockGPTServiceMockRecorder) Invoke(ctx, req any) *MockGPTServiceInvokeCall { +func (mr *MockServiceMockRecorder) Invoke(ctx, req any) *ServiceInvokeCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Invoke", reflect.TypeOf((*MockGPTService)(nil).Invoke), ctx, req) - return &MockGPTServiceInvokeCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Invoke", reflect.TypeOf((*MockService)(nil).Invoke), ctx, req) + return &ServiceInvokeCall{Call: call} } -// MockGPTServiceInvokeCall wrap *gomock.Call -type MockGPTServiceInvokeCall struct { +// ServiceInvokeCall wrap *gomock.Call +type ServiceInvokeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockGPTServiceInvokeCall) Return(arg0 domain.GPTResponse, arg1 error) *MockGPTServiceInvokeCall { +func (c *ServiceInvokeCall) Return(arg0 domain.GPTResponse, arg1 error) *ServiceInvokeCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockGPTServiceInvokeCall) Do(f func(context.Context, domain.GPTRequest) (domain.GPTResponse, error)) *MockGPTServiceInvokeCall { +func (c *ServiceInvokeCall) Do(f func(context.Context, domain.GPTRequest) (domain.GPTResponse, error)) *ServiceInvokeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockGPTServiceInvokeCall) DoAndReturn(f func(context.Context, domain.GPTRequest) (domain.GPTResponse, error)) *MockGPTServiceInvokeCall { +func (c *ServiceInvokeCall) DoAndReturn(f func(context.Context, domain.GPTRequest) (domain.GPTResponse, error)) *ServiceInvokeCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/internal/ai/type.go b/internal/ai/type.go index e40b1e09..077b0243 100644 --- a/internal/ai/type.go +++ b/internal/ai/type.go @@ -2,9 +2,9 @@ package ai import ( "github.com/ecodeclub/webook/internal/ai/internal/domain" - "github.com/ecodeclub/webook/internal/ai/internal/service" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt" ) type GPTRequest = domain.GPTRequest type GPTResponse = domain.GPTResponse -type GPTService = service.GPTService +type GPTService = gpt.Service diff --git a/internal/ai/wire.go b/internal/ai/wire.go index c14a1c93..1d8a18c4 100644 --- a/internal/ai/wire.go +++ b/internal/ai/wire.go @@ -5,32 +5,42 @@ package ai import ( "sync" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/config" + aicredit "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/credit" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/log" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/record" + "github.com/ecodeclub/webook/internal/ai/internal/repository" "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" - "github.com/ecodeclub/webook/internal/ai/internal/service" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/biz" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/config" - aiCredit "github.com/ecodeclub/webook/internal/ai/internal/service/handler/credit" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/getter" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/sdk" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/log" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/response" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/simple" "github.com/ecodeclub/webook/internal/credit" "github.com/ego-component/egorm" "github.com/google/wire" "gorm.io/gorm" ) -func InitModule(db *egorm.Component, - aisdk sdk.GPTSdk, creditSvc credit.Service) (*Module, error) { - wire.Build(InitGPTDAO, +func InitModule(db *egorm.Component, creditSvc *credit.Module) (*Module, error) { + wire.Build( + gpt.NewGPTService, repository.NewGPTLogRepo, - InitHandlers, - service.NewGPTService, + repository.NewGPTCreditLogRepo, + repository.NewCachedConfigRepository, + + InitGPTCreditLogDAO, + dao.NewGORMGPTLogDAO, + dao.NewGORMConfigDAO, + + config.NewBuilder, + log.NewHandler, + record.NewHandler, + aicredit.NewHandlerBuilder, + + InitHandlerFacade, + InitCommonHandlers, + InitZhipu, + wire.Struct(new(Module), "*"), + wire.FieldsOf(new(*credit.Module), "Svc"), ) return new(Module), nil } @@ -46,33 +56,7 @@ func InitTableOnce(db *gorm.DB) { }) } -func InitGPTDAO(db *egorm.Component) dao.GPTLogDAO { +func InitGPTCreditLogDAO(db *egorm.Component) dao.GPTCreditDAO { InitTableOnce(db) - return dao.NewGPTLogDAO(db) -} - -func InitGptHandler(sdk1 sdk.GPTSdk) *gpt.Handler { - sdkGetter := getter.NewPollingGetter([]sdk.GPTSdk{sdk1}) - gptHandler, err := gpt.NewHandler(sdkGetter) - if err != nil { - panic(err) - } - return gptHandler -} - -func InitHandlers(repo repository.GPTLogRepo, sdk1 sdk.GPTSdk, creditSvc credit.Service) []handler.GptHandler { - logHandler := log.NewHandler() - creditHandler := aiCredit.NewHandler(creditSvc, repo) - gptHandler := InitGptHandler(sdk1) - configHandler := config.InitHandler() - simpleHandler := simple.InitHandler(logHandler, creditHandler, gptHandler) - bizHandler := biz.NewHandler(map[string]handler.GptHandler{ - "simple": simpleHandler, - }) - responseHandler := response.NewHandler(repo) - return []handler.GptHandler{ - responseHandler, - configHandler, - bizHandler, - } + return dao.NewGPTCreditLogDAO(db) } diff --git a/internal/ai/wire_gen.go b/internal/ai/wire_gen.go index fdb43552..c11b861c 100644 --- a/internal/ai/wire_gen.go +++ b/internal/ai/wire_gen.go @@ -1,6 +1,6 @@ // Code generated by Wire. DO NOT EDIT. -//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:generate go run github.com/google/wire/cmd/wire //go:build !wireinject // +build !wireinject @@ -11,17 +11,11 @@ import ( "github.com/ecodeclub/webook/internal/ai/internal/repository" "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" - service2 "github.com/ecodeclub/webook/internal/ai/internal/service" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/biz" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/config" - credit2 "github.com/ecodeclub/webook/internal/ai/internal/service/handler/credit" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/getter" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/gpt/sdk" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/log" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/response" - "github.com/ecodeclub/webook/internal/ai/internal/service/handler/simple" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/config" + credit2 "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/credit" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/log" + "github.com/ecodeclub/webook/internal/ai/internal/service/gpt/handler/record" "github.com/ecodeclub/webook/internal/credit" "github.com/ego-component/egorm" "gorm.io/gorm" @@ -29,11 +23,22 @@ import ( // Injectors from wire.go: -func InitModule(db *gorm.DB, aisdk sdk.GPTSdk, creditSvc credit.Service) (*Module, error) { - gptLogDAO := InitGPTDAO(db) - gptLogRepo := repository.NewGPTLogRepo(gptLogDAO) - v := InitHandlers(gptLogRepo, aisdk, creditSvc) - gptService := service2.NewGPTService(v) +func InitModule(db *gorm.DB, creditSvc *credit.Module) (*Module, error) { + handlerBuilder := log.NewHandler() + configDAO := dao.NewGORMConfigDAO(db) + configRepository := repository.NewCachedConfigRepository(configDAO) + configHandlerBuilder := config.NewBuilder(configRepository) + service := creditSvc.Svc + gptCreditDAO := InitGPTCreditLogDAO(db) + gptCreditLogRepo := repository.NewGPTCreditLogRepo(gptCreditDAO) + creditHandlerBuilder := credit2.NewHandlerBuilder(service, gptCreditLogRepo) + gptRecordDAO := dao.NewGORMGPTLogDAO(db) + gptLogRepo := repository.NewGPTLogRepo(gptRecordDAO) + recordHandlerBuilder := record.NewHandler(gptLogRepo) + v := InitCommonHandlers(handlerBuilder, configHandlerBuilder, creditHandlerBuilder, recordHandlerBuilder) + handler := InitZhipu() + facadeHandler := InitHandlerFacade(v, handler) + gptService := gpt.NewGPTService(facadeHandler) module := &Module{ Svc: gptService, } @@ -53,33 +58,7 @@ func InitTableOnce(db *gorm.DB) { }) } -func InitGPTDAO(db *egorm.Component) dao.GPTLogDAO { +func InitGPTCreditLogDAO(db *egorm.Component) dao.GPTCreditDAO { InitTableOnce(db) - return dao.NewGPTLogDAO(db) -} - -func InitGptHandler(sdk1 sdk.GPTSdk) *gpt.Handler { - sdkGetter := getter.NewPollingGetter([]sdk.GPTSdk{sdk1}) - gptHandler, err := gpt.NewHandler(sdkGetter) - if err != nil { - panic(err) - } - return gptHandler -} - -func InitHandlers(repo repository.GPTLogRepo, sdk1 sdk.GPTSdk, creditSvc credit.Service) []handler.GptHandler { - logHandler := log.NewHandler() - creditHandler := credit2.NewHandler(creditSvc, repo) - gptHandler := InitGptHandler(sdk1) - configHandler := config.InitHandler() - simpleHandler := simple.InitHandler(logHandler, creditHandler, gptHandler) - bizHandler := biz.NewHandler(map[string]handler.GptHandler{ - "simple": simpleHandler, - }) - responseHandler := response.NewHandler(repo) - return []handler.GptHandler{ - responseHandler, - configHandler, - bizHandler, - } + return dao.NewGPTCreditLogDAO(db) } diff --git a/internal/question/internal/domain/examine.go b/internal/question/internal/domain/examine.go index dbd7c5be..8adfb2af 100644 --- a/internal/question/internal/domain/examine.go +++ b/internal/question/internal/domain/examine.go @@ -21,7 +21,7 @@ type ExamineResult struct { RawResult string // 使用的 token 数量 - Tokens int + Tokens int64 // 花费的金额 Amount int64 Tid string diff --git a/internal/question/internal/integration/admin_handler_test.go b/internal/question/internal/integration/admin_handler_test.go index 01db0eef..827643cf 100644 --- a/internal/question/internal/integration/admin_handler_test.go +++ b/internal/question/internal/integration/admin_handler_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "github.com/ecodeclub/webook/internal/ai" + "github.com/ecodeclub/webook/internal/permission" "github.com/ecodeclub/webook/internal/interactive" @@ -99,7 +101,7 @@ func (s *AdminHandlerTestSuite) SetupSuite() { return res, nil }).AnyTimes() - module, err := startup.InitModule(s.producer, intrModule, &permission.Module{}) + module, err := startup.InitModule(s.producer, intrModule, &permission.Module{}, &ai.Module{}) require.NoError(s.T(), err) econf.Set("server", map[string]any{"contextTimeout": "1s"}) server := egin.Load("server").Build() diff --git a/internal/question/internal/integration/admin_set_handler_test.go b/internal/question/internal/integration/admin_set_handler_test.go index 4a2b4650..21d32fa8 100644 --- a/internal/question/internal/integration/admin_set_handler_test.go +++ b/internal/question/internal/integration/admin_set_handler_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "github.com/ecodeclub/webook/internal/ai" + "github.com/ecodeclub/webook/internal/permission" "github.com/ecodeclub/ecache" @@ -64,7 +66,8 @@ func (s *AdminSetHandlerTestSuite) SetupSuite() { intrModule := &interactive.Module{} - module, err := startup.InitModule(s.producer, intrModule, &permission.Module{}) + module, err := startup.InitModule(s.producer, intrModule, + &permission.Module{}, &ai.Module{}) require.NoError(s.T(), err) econf.Set("server", map[string]any{"contextTimeout": "1s"}) server := egin.Load("server").Build() diff --git a/internal/question/internal/integration/examine_handler_test.go b/internal/question/internal/integration/examine_handler_test.go index 10b040f2..cd6241fd 100644 --- a/internal/question/internal/integration/examine_handler_test.go +++ b/internal/question/internal/integration/examine_handler_test.go @@ -22,6 +22,10 @@ import ( "testing" "time" + "github.com/ecodeclub/webook/internal/ai" + aimocks "github.com/ecodeclub/webook/internal/ai/mocks" + "go.uber.org/mock/gomock" + "github.com/ecodeclub/webook/internal/permission" "github.com/ecodeclub/ekit/iox" @@ -50,7 +54,16 @@ type ExamineHandlerTest struct { } func (s *ExamineHandlerTest) SetupSuite() { - module, err := startup.InitModule(nil, &interactive.Module{}, &permission.Module{}) + ctrl := gomock.NewController(s.T()) + aiSvc := aimocks.NewMockService(ctrl) + aiSvc.EXPECT().Invoke(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req ai.GPTRequest) (ai.GPTResponse, error) { + return ai.GPTResponse{ + Tokens: req.Uid, + Amount: req.Uid, + Answer: "评分:15K", + }, nil + }).AnyTimes() + module, err := startup.InitModule(nil, &interactive.Module{}, &permission.Module{}, &ai.Module{Svc: aiSvc}) require.NoError(s.T(), err) hdl := module.ExamineHdl s.db = testioc.InitDB() diff --git a/internal/question/internal/integration/handler_test.go b/internal/question/internal/integration/handler_test.go index 94d72ef1..0bbdaf31 100644 --- a/internal/question/internal/integration/handler_test.go +++ b/internal/question/internal/integration/handler_test.go @@ -24,6 +24,8 @@ import ( "testing" "time" + "github.com/ecodeclub/webook/internal/ai" + "github.com/ecodeclub/webook/internal/permission" permissionmocks "github.com/ecodeclub/webook/internal/permission/mocks" "github.com/ecodeclub/webook/internal/question/internal/errs" @@ -101,7 +103,8 @@ func (s *HandlerTestSuite) SetupSuite() { return perm.BizID%2 == 0, nil }).AnyTimes() - module, err := startup.InitModule(producer, intrModule, &permission.Module{Svc: permSvc}) + module, err := startup.InitModule(producer, intrModule, + &permission.Module{Svc: permSvc}, &ai.Module{}) require.NoError(s.T(), err) econf.Set("server", map[string]any{"contextTimeout": "1s"}) server := egin.Load("server").Build() diff --git a/internal/question/internal/integration/knowledge_job_starter_test.go b/internal/question/internal/integration/knowledge_job_starter_test.go index bcb06bf1..3575bd85 100644 --- a/internal/question/internal/integration/knowledge_job_starter_test.go +++ b/internal/question/internal/integration/knowledge_job_starter_test.go @@ -23,6 +23,8 @@ import ( "os" "time" + "github.com/ecodeclub/webook/internal/ai" + "github.com/ecodeclub/webook/internal/interactive" "github.com/ecodeclub/webook/internal/permission" "github.com/ecodeclub/webook/internal/question/internal/domain" @@ -43,7 +45,7 @@ type KnowledgeJobStarterTestSuite struct { } func (s *KnowledgeJobStarterTestSuite) SetupSuite() { - module, err := startup.InitModule(nil, &interactive.Module{}, &permission.Module{}) + module, err := startup.InitModule(nil, &interactive.Module{}, &permission.Module{}, &ai.Module{}) require.NoError(s.T(), err) s.starter = module.KnowledgeJobStarter s.db = testioc.InitDB() diff --git a/internal/question/internal/integration/set_handler_test.go b/internal/question/internal/integration/set_handler_test.go index 69af6bf4..412cb7e7 100644 --- a/internal/question/internal/integration/set_handler_test.go +++ b/internal/question/internal/integration/set_handler_test.go @@ -24,6 +24,8 @@ import ( "testing" "time" + "github.com/ecodeclub/webook/internal/ai" + "github.com/ecodeclub/webook/internal/permission" "github.com/ecodeclub/ecache" @@ -89,7 +91,7 @@ func (s *SetHandlerTestSuite) SetupSuite() { return res, nil }).AnyTimes() - module, err := startup.InitModule(s.producer, intrModule, &permission.Module{}) + module, err := startup.InitModule(s.producer, intrModule, &permission.Module{}, &ai.Module{}) require.NoError(s.T(), err) econf.Set("server", map[string]any{"contextTimeout": "1s"}) server := egin.Load("server").Build() diff --git a/internal/question/internal/integration/startup/wire.go b/internal/question/internal/integration/startup/wire.go index 27d3125e..6bcc509a 100644 --- a/internal/question/internal/integration/startup/wire.go +++ b/internal/question/internal/integration/startup/wire.go @@ -19,6 +19,8 @@ package startup import ( "os" + "github.com/ecodeclub/webook/internal/ai" + "github.com/ecodeclub/webook/internal/interactive" "github.com/ecodeclub/webook/internal/permission" baguwen "github.com/ecodeclub/webook/internal/question" @@ -35,6 +37,7 @@ import ( func InitModule(p event.SyncDataToSearchEventProducer, intrModule *interactive.Module, permModule *permission.Module, + aiModule *ai.Module, ) (*baguwen.Module, error) { wire.Build( testioc.BaseSet, @@ -42,6 +45,7 @@ func InitModule(p event.SyncDataToSearchEventProducer, event.NewInteractiveEventProducer, wire.FieldsOf(new(*interactive.Module), "Svc"), wire.FieldsOf(new(*permission.Module), "Svc"), + wire.FieldsOf(new(*ai.Module), "Svc"), ) return new(baguwen.Module), nil } diff --git a/internal/question/internal/integration/startup/wire_gen.go b/internal/question/internal/integration/startup/wire_gen.go index 8360c683..f56b0f18 100644 --- a/internal/question/internal/integration/startup/wire_gen.go +++ b/internal/question/internal/integration/startup/wire_gen.go @@ -9,6 +9,7 @@ package startup import ( "os" + "github.com/ecodeclub/webook/internal/ai" "github.com/ecodeclub/webook/internal/interactive" "github.com/ecodeclub/webook/internal/permission" baguwen "github.com/ecodeclub/webook/internal/question" @@ -25,7 +26,7 @@ import ( // Injectors from wire.go: -func InitModule(p event.SyncDataToSearchEventProducer, intrModule *interactive.Module, permModule *permission.Module) (*baguwen.Module, error) { +func InitModule(p event.SyncDataToSearchEventProducer, intrModule *interactive.Module, permModule *permission.Module, aiModule *ai.Module) (*baguwen.Module, error) { db := testioc.InitDB() questionDAO := baguwen.InitQuestionDAO(db) ecacheCache := testioc.InitCache() @@ -45,7 +46,8 @@ func InitModule(p event.SyncDataToSearchEventProducer, intrModule *interactive.M service2 := intrModule.Svc examineDAO := dao.NewGORMExamineDAO(db) examineRepository := repository.NewCachedExamineRepository(examineDAO) - examineService := service.NewGPTExamineService(repositoryRepository, examineRepository) + gptService := aiModule.Svc + examineService := service.NewGPTExamineService(repositoryRepository, examineRepository, gptService) service3 := permModule.Svc handler := web.NewHandler(service2, examineService, service3, serviceService) questionSetHandler := web.NewQuestionSetHandler(questionSetService, examineService, service2) diff --git a/internal/question/internal/repository/dao/examine_types.go b/internal/question/internal/repository/dao/examine_types.go index e4162911..53fba296 100644 --- a/internal/question/internal/repository/dao/examine_types.go +++ b/internal/question/internal/repository/dao/examine_types.go @@ -26,7 +26,7 @@ type ExamineRecord struct { // 原始的 AI 回答 RawResult string // 冗余字段,使用的 tokens 数量 - Tokens int + Tokens int64 // 冗余字段,花费的金额 Amount int64 diff --git a/internal/question/internal/service/examine.go b/internal/question/internal/service/examine.go index d2500297..a0b0b46c 100644 --- a/internal/question/internal/service/examine.go +++ b/internal/question/internal/service/examine.go @@ -109,10 +109,11 @@ func (svc *GPTExamineService) parseExamineResult(answer string) domain.Result { func NewGPTExamineService( queRepo repository.Repository, repo repository.ExamineRepository, + aiSvc ai.GPTService, ) ExamineService { return &GPTExamineService{ queRepo: queRepo, repo: repo, - aiSvc: &AiService{}, + aiSvc: aiSvc, } } diff --git a/internal/question/internal/web/examine_vo.go b/internal/question/internal/web/examine_vo.go index 668a7493..0f5cff49 100644 --- a/internal/question/internal/web/examine_vo.go +++ b/internal/question/internal/web/examine_vo.go @@ -28,7 +28,7 @@ type ExamineResult struct { RawResult string `json:"rawResult"` // 使用的 token 数量 - Tokens int `json:"tokens"` + Tokens int64 `json:"tokens"` // 花费的金额 Amount int64 `json:"amount"` } diff --git a/internal/question/wire.go b/internal/question/wire.go index 85abacb2..70a62a6b 100644 --- a/internal/question/wire.go +++ b/internal/question/wire.go @@ -19,6 +19,8 @@ package baguwen import ( "sync" + "github.com/ecodeclub/webook/internal/ai" + "github.com/gotomicro/ego/core/econf" "github.com/ecodeclub/webook/internal/question/internal/job" @@ -52,6 +54,7 @@ func InitModule(db *egorm.Component, intrModule *interactive.Module, ec ecache.Cache, perm *permission.Module, + aiModule *ai.Module, q mq.MQ) (*Module, error) { wire.Build(InitQuestionDAO, cache.NewQuestionECache, @@ -73,6 +76,7 @@ func InitModule(db *egorm.Component, wire.FieldsOf(new(*interactive.Module), "Svc"), wire.FieldsOf(new(*permission.Module), "Svc"), + wire.FieldsOf(new(*ai.Module), "Svc"), wire.Struct(new(Module), "*"), ) diff --git a/internal/question/wire_gen.go b/internal/question/wire_gen.go index 460414c4..9bffa55d 100644 --- a/internal/question/wire_gen.go +++ b/internal/question/wire_gen.go @@ -11,6 +11,7 @@ import ( "github.com/ecodeclub/ecache" "github.com/ecodeclub/mq-api" + "github.com/ecodeclub/webook/internal/ai" "github.com/ecodeclub/webook/internal/interactive" "github.com/ecodeclub/webook/internal/permission" "github.com/ecodeclub/webook/internal/question/internal/event" @@ -28,7 +29,7 @@ import ( // Injectors from wire.go: -func InitModule(db *gorm.DB, intrModule *interactive.Module, ec ecache.Cache, perm *permission.Module, q mq.MQ) (*Module, error) { +func InitModule(db *gorm.DB, intrModule *interactive.Module, ec ecache.Cache, perm *permission.Module, aiModule *ai.Module, q mq.MQ) (*Module, error) { questionDAO := InitQuestionDAO(db) questionCache := cache.NewQuestionECache(ec) repositoryRepository := repository.NewCacheRepository(questionDAO, questionCache) @@ -49,7 +50,8 @@ func InitModule(db *gorm.DB, intrModule *interactive.Module, ec ecache.Cache, pe service2 := intrModule.Svc examineDAO := dao.NewGORMExamineDAO(db) examineRepository := repository.NewCachedExamineRepository(examineDAO) - examineService := service.NewGPTExamineService(repositoryRepository, examineRepository) + gptService := aiModule.Svc + examineService := service.NewGPTExamineService(repositoryRepository, examineRepository, gptService) service3 := perm.Svc handler := web.NewHandler(service2, examineService, service3, serviceService) questionSetHandler := web.NewQuestionSetHandler(questionSetService, examineService, service2) diff --git a/ioc/wire.go b/ioc/wire.go index 7ab16485..f1c0342f 100644 --- a/ioc/wire.go +++ b/ioc/wire.go @@ -17,6 +17,7 @@ package ioc import ( + "github.com/ecodeclub/webook/internal/ai" "github.com/ecodeclub/webook/internal/cases" "github.com/ecodeclub/webook/internal/cos" "github.com/ecodeclub/webook/internal/credit" @@ -83,6 +84,7 @@ func InitApp() (*App, error) { wire.FieldsOf(new(*search.Module), "Hdl"), roadmap.InitModule, wire.FieldsOf(new(*roadmap.Module), "Hdl", "AdminHdl"), + ai.InitModule, initLocalActiveLimiterBuilder, initCronJobs, // 这两个顺序不要换 diff --git a/ioc/wire_gen.go b/ioc/wire_gen.go index ba60ed4d..e177b9a5 100644 --- a/ioc/wire_gen.go +++ b/ioc/wire_gen.go @@ -7,6 +7,7 @@ package ioc import ( + "github.com/ecodeclub/webook/internal/ai" "github.com/ecodeclub/webook/internal/cases" "github.com/ecodeclub/webook/internal/cos" "github.com/ecodeclub/webook/internal/credit" @@ -54,7 +55,15 @@ func InitApp() (*App, error) { return nil, err } cache := InitCache(cmdable) - baguwenModule, err := baguwen.InitModule(db, interactiveModule, cache, permissionModule, mq) + creditModule, err := credit.InitModule(db, mq, cache) + if err != nil { + return nil, err + } + aiModule, err := ai.InitModule(db, creditModule) + if err != nil { + return nil, err + } + baguwenModule, err := baguwen.InitModule(db, interactiveModule, cache, permissionModule, aiModule, mq) if err != nil { return nil, err } @@ -83,10 +92,6 @@ func InitApp() (*App, error) { return nil, err } handler7 := productModule.Hdl - creditModule, err := credit.InitModule(db, mq, cache) - if err != nil { - return nil, err - } paymentModule, err := payment.InitModule(db, mq, cache, creditModule) if err != nil { return nil, err