-
Notifications
You must be signed in to change notification settings - Fork 0
/
intent_classifier.py
36 lines (30 loc) · 1.46 KB
/
intent_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Load the pre-trained BERT tokenizer and a BERT encoder with a two-label
# sequence-classification head.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Overwrite the pre-trained weights with the fine-tuned ones.
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; the code visible here never moves the model to another device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load("multiclass_bert_model.pth", map_location="cpu"))
model.eval()  # Evaluation mode: turns off dropout for inference
def predict_sentiment(text, confidence_threshold=0.7):
    """Classify the intent of a single text as "Question" or "Action".

    Despite its name, this function predicts an intent label, not a
    sentiment. It relies on the module-level ``tokenizer`` and ``model``.

    Args:
        text: The input string to classify. Inputs longer than 512 tokens
            are truncated by the tokenizer.
        confidence_threshold: Minimum softmax probability the top class must
            reach. Below it, the label "unknown" is returned instead of a
            class name. Pass 0 to always get the top class.

    Returns:
        A ``(label_name, confidence)`` tuple. ``label_name`` is "Question",
        "Action" or "unknown"; ``confidence`` is the softmax probability of
        the top class as a Python float.
    """
    # Assumed to follow the label order used during fine-tuning — TODO confirm.
    labels = ("Question", "Action")

    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single text was passed in, so row 0 holds its class probabilities.
    probabilities = torch.softmax(logits, dim=1)[0]
    predicted_label = int(torch.argmax(probabilities).item())
    # Report the real probability of the winning class so the threshold below
    # can actually reject low-confidence predictions.
    max_confidence = float(probabilities[predicted_label].item())

    if max_confidence < confidence_threshold:
        return "unknown", max_confidence
    return labels[predicted_label], max_confidence
# Example usage:
# text_to_predict = "tell priyansh to give me 20 rupees" # Replace with your text
# predicted_class = predict_sentiment(text_to_predict)
# print("predicted_sentiment = ", predicted_class)