-
Notifications
You must be signed in to change notification settings - Fork 0
/
Claude_triage_GeneralUser.py
56 lines (36 loc) · 2.77 KB
/
Claude_triage_GeneralUser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
## PREDICT TRIAGE/ACUITY GENERAL USER CASE
## import libraries
import pandas as pd
from tqdm import tqdm
import os
import boto3
from langchain.prompts import PromptTemplate
from langchain_aws import ChatBedrock
## Import Functions
from functions.LLM_predictions import get_prediction_GeneralUser
## Load the triage dataset produced by create_ground_truth_specialty.py.
## Prediction columns are added to this frame and written back to the same CSV.
triage_csv_path = "MIMIC-IV-Ext-Triage.csv"
df = pd.read_csv(triage_csv_path)
## Define the prompt template.
## The wording below is the exact text sent to every model. It is kept verbatim so that
## predictions stay comparable across runs; editing it changes model behaviour.
## Placeholders filled by PromptTemplate: {HPI} (history of present illness) and
## {patient_info}. The model is asked to answer inside an <acuity> tag; presumably
## get_prediction_GeneralUser extracts the level from that tag (TODO confirm).
## NOTE(review): criterion 2 has no trailing period and "his information" assumes the
## patient's gender. These are left as-is on purpose (see above).
prompt = """You are a nurse with emergency and triage experience. Using the patient's history of present illness and his information, determine the triage level based on the Emergency Severity Index (ESI), ranging from ESI level 1 (highest acuity) to ESI level 5 (lowest acuity): 1: Assign if the patient requires immediate lifesaving intervention. 2: Assign if the patient is in a high-risk situation (e.g., confused, lethargic, disoriented, or experiencing severe pain/distress) 3: Assign if the patient requires two or more diagnostic or therapeutic interventions and their vital signs are within acceptable limits for non-urgent care. 4: Assign if the patient requires one diagnostic or therapeutic intervention (e.g., lab test, imaging, or EKG). 5: Assign if the patient does not require any diagnostic or therapeutic interventions beyond a physical exam (e.g., no labs, imaging, or wound care).
History of present illness: {HPI} and patient info: {patient_info}. Respond with the level in an <acuity> tag."""
## set AWS credentials
# Replace the placeholder strings below with real keys, or leave them untouched to
# fall back to boto3's default credential chain (existing environment variables,
# ~/.aws/credentials, an attached IAM role, ...). Avoid committing real keys.
AWS_ACCESS_KEY_ID = "Enter your AWS Access Key ID"
AWS_SECRET_ACCESS_KEY = "Enter your AWS Secret Access Key"
for _env_name, _env_value in (
    ("AWS_ACCESS_KEY_ID", AWS_ACCESS_KEY_ID),
    ("AWS_SECRET_ACCESS_KEY", AWS_SECRET_ACCESS_KEY),
):
    # Exporting an unfilled placeholder would overwrite valid credentials already in
    # the environment and hide every other credential source from boto3, so only
    # values the user actually filled in are exported.
    if not _env_value.startswith("Enter your"):
        os.environ[_env_name] = _env_value
## Build one prompt -> Bedrock chain per evaluated Claude model.
# `input_variables` must name the template placeholders exactly (case-sensitive).
# The template uses `{HPI}`, so the previously declared "hpi" did not match it.
prompt_chain = PromptTemplate(template=prompt, input_variables=["HPI", "patient_info"])
client = boto3.client(service_name="bedrock-runtime", region_name="us-east-1")


def _build_claude_chain(model_id):
    """Return ``(llm, chain)`` for the Bedrock model ``model_id``.

    The model is created with ``temperature=0`` (to reduce run-to-run variation)
    on the shared ``client``. ``chain`` pipes ``prompt_chain`` into that model, so
    invoking it formats the prompt and sends it to the model.
    """
    llm = ChatBedrock(model_id=model_id, model_kwargs={"temperature": 0}, client=client)
    return llm, prompt_chain | llm


## Claude 3.5 Sonnet
llm_claude35, chain_claude35 = _build_claude_chain("anthropic.claude-3-5-sonnet-20240620-v1:0")
## Claude 3 Sonnet
llm_claude3, chain_claude3 = _build_claude_chain("anthropic.claude-3-sonnet-20240229-v1:0")
## Claude 3 Haiku
llm_haiku, chain_haiku = _build_claude_chain("anthropic.claude-3-haiku-20240307-v1:0")
## Register tqdm's pandas integration so `progress_apply` shows a progress bar.
tqdm.pandas()
## Run each model in turn and store its predictions in its own column.
## The CSV is rewritten after every model, so predictions from models that already
## finished are saved even if a later model's run fails.
for column_name, model_chain in (
    ("triage_Claude3.5", chain_claude35),
    ("triage_Claude3", chain_claude3),
    ("triage_Haiku", chain_haiku),
):
    # Bind the current chain as a default so the lambda does not depend on the loop variable.
    df[column_name] = df.progress_apply(
        lambda row, _chain=model_chain: get_prediction_GeneralUser(row, _chain),
        axis=1,
    )
    df.to_csv('MIMIC-IV-Ext-Triage.csv', index=False)