diff --git a/README.md b/README.md index 5e4f266..a7ef814 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,7 @@ The configuration file is saved in `config.json`. ## Plans -- [ ] Support task `label_initial_prompt` - - This task is a bit hard to find for my language configuration, I'll keep finding it. +- [x] ~~Support task `label_initial_prompt`~~ - [ ] Support other language models (maybe) There have been many excellent LLMs coming out recently, like GPT4All. I may try to support them in my free time(but no guarantees). if you have ideas, PRs are welcome! diff --git a/README_zh.md b/README_zh.md index 2350348..8316053 100644 --- a/README_zh.md +++ b/README_zh.md @@ -29,9 +29,7 @@ ## 计划 -- [ ] 支持任务 `label_initial_prompt` - - 目前中文的任务还是比较少的,尤其这个类别很少出现。等下次我遇到了再做支持,也欢迎大家 PR。 +- [x] ~~支持任务 `label_initial_prompt`~~ - [ ] 支持其他语言模型(或许) 最近出现了许多优秀的 LLM,比如 GPT4All,我有空的话可能会尝试支持它们(但不能保证)。如果您有想法,欢迎提 PR! diff --git a/model/model.go b/model/model.go index d174fa6..e6cf1b8 100644 --- a/model/model.go +++ b/model/model.go @@ -78,6 +78,20 @@ type OALabelPrompterReplyTask struct { UserId string `json:"userId"` } +type OALabelInitialPromptTask struct { + Id string `json:"id"` + Mode string `json:"mode"` + Type string `json:"type"` + Labels []OARandomTaskLabel `json:"labels"` + Prompt string `json:"prompt"` + MessageId string `json:"message_id"` + Disposition string `json:"disposition"` + Conversation OAConversation `json:"conversation"` + ValidLabels []OALabel `json:"valid_labels"` + MandatoryLabels []OALabel `json:"mandatory_labels"` + UserId string `json:"userId"` +} + type OAAssistantReplyTask struct { Id string `json:"id"` Type string `json:"type"` @@ -162,3 +176,15 @@ type PostBodyRankAssistant struct { UpdateType string `json:"update_type"` Content PostBodyRankAssistantContent `json:"content"` } + +type PostBodyLabelInitialPromptContent struct { + Labels map[OALabel]float32 `json:"labels"` + MessageId string `json:"message_id"` + Text string `json:"text"` +} +type PostBodyLabelInitialPrompt struct { + Id string `json:"id"` + Lang string `json:"lang"` + UpdateType string `json:"update_type"` + Content PostBodyLabelInitialPromptContent `json:"content"` +} diff --git a/open-assistant.go b/open-assistant.go index 50e2d20..d76b978 100644 --- a/open-assistant.go +++ b/open-assistant.go @@ -40,8 +40,8 @@ func RefreshCookie() error { return err } cs := resp.Header()["Set-Cookie"] - if len(cs) == 0 { - return fmt.Errorf("no cookie") + if len(cs) <= 1 { + return fmt.Errorf("no cookie. Please consider going to https://open-assistant.io , login, and update your cookie in config.json") } c := strings.Join(cs, "") return model.UpdateCookie(c) @@ -150,7 +150,7 @@ func LabelPrompterReply(id string, task model.OALabelPrompterReplyTask) error { text += fmt.Sprintf("User: %s", task.Reply) } - t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given conversations between a user and the model, and you need to label the user's last reply and return a JSON string.You should evaluate the conversations based on the following criteria: + t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given conversations between a user and the model, and you need to label the user's last reply and return a JSON string. You should evaluate the conversations based on the following criteria: - Spam: 0/1, whether the conversation contains spam / ads / porn / politics / etc. - Not Appropriate: 0/1, whether the response is reasonable for the user's question - pii: 0/1 @@ -206,6 +206,67 @@ You must return a JSON string, DO NOT include any other characters, DO NOT expla return nil } +func LabelInitialPrompt(id string, task model.OALabelInitialPromptTask) error { + logx.Infof("LabelInitialPrompt") + text := fmt.Sprintf("Prompt: %s\n\nUser's language code: %s", task.Prompt, model.Conf.Language) + + t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given a prompt from the user, and you need to label it and return a JSON string. You should evaluate the prompt based on the following criteria: +- Spam: 0/1, whether the message contains spam / ads / porn / politics / etc. +- Not Appropriate: 0/1, whether the message is offensive / not respectful +- pii: 0/1 +- Hate Speech: 0/1, whether the prompt is aggressive / not respectful +- Sexual Content: 0/1 +- Quality: 0-1, step 0.25, how well the response is written respecting grammar, spelling, use of words, etc. +- Lang Mismatch: 0-1, step 0.25, whether the prompt is in the same language as the user's language +- Creativity: 0-1, step 0.25, how less is the prompt +- Humor: 0-1, step 0.25 +- Toxicity: 0-1, step 0.25, how aggressive is the prompt +- Violence: 0-1, step 0.25 + +You must return a JSON string, DO NOT include any other characters, DO NOT explain. Use snake_case for the keys.`) + + if err != nil { + return err + } + + logx.Infof("LabelInitialPrompt: %s", t) + labels, err := GetLabelsFromChatGPT(t) + if err != nil { + return err + } + + body := model.PostBodyLabelInitialPrompt{ + Id: id, + Lang: model.Conf.Language, + UpdateType: "text_labels", + Content: model.PostBodyLabelInitialPromptContent{ + Labels: labels, + Text: "unused?", + MessageId: task.MessageId, + }, + } + + resp, err := rty.R(). + SetHeaders(map[string]string{ + "Cookie": model.Conf.OaCookie, + "Content-Type": "application/json", + }). + SetBody(body). + Post("https://open-assistant.io/api/update_task") + + if err != nil { + return err + } + + respStr := string(resp.Body()) + if respStr == "" { + logx.Infof("LabelInitialPrompt: OK!") + } else { + logx.Errorf("LabelInitialPrompt: %s", respStr) + } + return nil +} + func AssistantReply(id string, task model.OAAssistantReplyTask) error { logx.Infof("AssistantReply") text := "" @@ -374,9 +435,20 @@ func StartTask() error { } var task model.OARandomTaskResponse - err = jsonx.Unmarshal(resp.Body(), &task) - if task.Task == nil { - logx.Infof("GetTasks: get task failed: %s", string(resp.Body())) + body := resp.Body() + err = jsonx.Unmarshal(body, &task) + if resp.StatusCode() == 403 { + logx.Errorf("GetTasks: cookie may have expired (403 Forbidden)") + logx.Errorf("Please login to https://open-assistant.io/dashboard and update your cookie in config.json") + return nil + } else if task.Task == nil { + var _json map[string]interface{} + _ = jsonx.Unmarshal(body, &_json) + if strings.Contains(_json["message"].(string), "No tasks") { + logx.Infof("GetTasks: no tasks at this time") + } + logx.Errorf("GetTasks: get task failed: %s", _json["message"]) + logx.Errorf("Please check your network, or login to https://open-assistant.io/dashboard and update your cookie in config.json") return nil } @@ -403,9 +475,16 @@ func StartTask() error { var ch model.OARankAssistantRepliesTask _ = jsonx.Unmarshal(j, &ch) return RankAssistantReplies(task.Id, ch) + } else if t["type"] == "label_initial_prompt" { + var ch model.OALabelInitialPromptTask + _ = jsonx.Unmarshal(j, &ch) + return LabelInitialPrompt(task.Id, ch) } else { logx.Infof("GetTasks: unknown task type: %s", t["type"]) - CancelTask(task.Id) + err = CancelTask(task.Id) + if err != nil { + return fmt.Errorf("GetTasks: cancel task failed: %s", err) + } } return nil