Skip to content

Commit

Permalink
feat(huggingface): setting for removing prompt from generated commit …
Browse files Browse the repository at this point in the history
…message

Closes #294
  • Loading branch information
Blarc committed Dec 1, 2024
1 parent 36d38a8 commit 23e2195
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- More options for configuring LLM clients.
- **Use the chosen LLM client icon as the generate commit message action's icon.**
- Option to stop the commit message generation by clicking the action icon again.
- Setting for HuggingFace client to automatically remove prompt from the generated commit message.

### Fixed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ abstract class LLMClientConfiguration(
getSharedState().modelIds.add(modelId)
}

/**
 * Hook invoked after a commit message has been generated; writes [result] into the IDE's
 * commit message field via [CommitMessage.setCommitMessage].
 *
 * The base implementation ignores [prompt] and applies [result] verbatim. Subclasses may
 * override this to post-process the model output first (e.g. the HuggingFace client strips
 * the echoed prompt when its `removePrompt` setting is enabled) before delegating to super.
 */
open fun setCommitMessage(commitMessage: CommitMessage, prompt: String, result: String) {
    commitMessage.setCommitMessage(result)
}

abstract fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project)

abstract fun getGenerateCommitMessageJob(): Job?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro

makeRequest(clientConfiguration, prompt, onSuccess = {
withContext(Dispatchers.EDT) {
commitMessage.setCommitMessage(it)
clientConfiguration.setCommitMessage(commitMessage, prompt, it)
}
AppSettings2.instance.recordHit()
}, onError = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class HuggingFaceClientConfiguration : LLMClientConfiguration(
var tokenIsStored: Boolean = false
@Transient
var token: String? = null
@Attribute
var removePrompt: Boolean = false

companion object {
const val CLIENT_NAME = "HuggingFace"
Expand All @@ -45,6 +47,15 @@ class HuggingFaceClientConfiguration : LLMClientConfiguration(
return HuggingFaceClientSharedState.getInstance()
}

/**
 * Applies the generated [result] to the commit message field, optionally stripping the
 * echoed [prompt] first.
 *
 * Some HuggingFace models return the generated text prefixed with the prompt that was sent
 * (see https://github.com/Blarc/ai-commits-intellij-plugin/issues/294). When the
 * [removePrompt] setting is enabled, that prefix is removed before delegating to super.
 */
override fun setCommitMessage(commitMessage: CommitMessage, prompt: String, result: String) {
    var newResult = result
    if (removePrompt) {
        // removePrefix is a no-op when the result does not actually start with the prompt,
        // unlike the naive `result.substring(prompt.length + 1)`, which would throw
        // StringIndexOutOfBoundsException for short results and corrupt results that do
        // not echo the prompt. trimStart() drops the newline/whitespace that typically
        // separates the echoed prompt from the real message.
        newResult = result.removePrefix(prompt).trimStart()
    }
    super.setCommitMessage(commitMessage, prompt, newResult)
}

override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
return HuggingFaceClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)
}
Expand All @@ -65,6 +76,7 @@ class HuggingFaceClientConfiguration : LLMClientConfiguration(
copy.timeout = timeout
copy.waitForModel = waitForModel
copy.maxNewTokens = maxNewTokens
copy.removePrompt = removePrompt
return copy
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class HuggingFaceClientPanel private constructor(
private val tokenPasswordField = JBPasswordField()
private val maxNewTokensTextField = JBTextField()
private val waitForModelCheckBox = JBCheckBox()
private val removePrompt = JBCheckBox()

constructor(configuration: HuggingFaceClientConfiguration) : this(configuration, HuggingFaceClientService.getInstance())

Expand All @@ -27,6 +28,7 @@ class HuggingFaceClientPanel private constructor(
temperatureRow()
maxNewTokens()
waitForModel()
removePrompt()
verifyRow()
}

Expand All @@ -40,6 +42,7 @@ class HuggingFaceClientPanel private constructor(
clientConfiguration.token = String(tokenPasswordField.password)
clientConfiguration.maxNewTokens = maxNewTokensTextField.text.toInt()
clientConfiguration.waitForModel = waitForModelCheckBox.isSelected
clientConfiguration.removePrompt = removePrompt.isSelected
service.verifyConfiguration(clientConfiguration, verifyLabel)
}

Expand Down Expand Up @@ -83,4 +86,18 @@ class HuggingFaceClientPanel private constructor(
.align(AlignX.RIGHT)
}
}

/**
 * Adds the "Remove prompt" settings row to the HuggingFace client panel:
 * a labeled checkbox bound to [HuggingFaceClientConfiguration.removePrompt],
 * with a right-aligned context-help icon explaining the option.
 */
private fun Panel.removePrompt() {
    row {
        // Shared width group keeps this label column-aligned with the other setting rows.
        label(message("settings.huggingface.removePrompt"))
            .widthGroup("label")
        // `removePrompt` here is the JBCheckBox field declared on the panel class,
        // not this extension function; bindSelected keeps it in sync with the configuration.
        cell(removePrompt)
            .bindSelected(clientConfiguration::removePrompt)
            .resizableColumn()
            .align(Align.FILL)

        contextHelp(message("settings.huggingface.removePrompt.comment"))
            .align(AlignX.RIGHT)
    }
}
}
2 changes: 2 additions & 0 deletions src/main/resources/messages/AiCommitsBundle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,6 @@ settings.huggingface.token.example=hf_fKASPPYLkasgjasKwpSnAASRdasdCdAsddsASSDF
settings.huggingface.maxNewTokens=Max new tokens
settings.huggingface.waitForModel=Wait for model
settings.huggingface.waitModel.comment=When a model is warm, it is ready to be used, and you will get a response relatively quickly. However, some models are cold and need to be loaded before they can be used. In that case, you will get a 503 error. Rather than doing many requests until it is loaded, you can wait for the model to be loaded.
settings.huggingface.removePrompt=Remove prompt
settings.huggingface.removePrompt.comment=Some models return the result prefixed with the prompt. Checking this option will remove the prompt from the result.

0 comments on commit 23e2195

Please sign in to comment.