diff --git a/.gitignore b/.gitignore
index a25e14a..9a79f19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,8 @@ scripts/*.ps1
scripts/*.sh
**/dist
**/build
-*.log
\ No newline at end of file
+*.log
+benchmark/
+modelTest/
+nc_workspace/
+debug_openai_history.txt
\ No newline at end of file
diff --git a/README.md b/README.md
index 4e4c883..b96bedd 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
| Model architectures | Gemma<br/>Llama \*<br/>Mistral +<br/>Phi | | |
| Platform | Linux<br/>Windows | | |
| Architecture | x86<br/>x64 | Arm64 | |
-| Hardware Acceleration | CUDA<br/>DirectML<br/>IpexLLM | QNN<br/>ROCm | OpenVINO |
+| Hardware Acceleration | CUDA<br/>DirectML<br/>IpexLLM<br/>OpenVINO | QNN<br/>ROCm | |
\* The Llama model architecture supports similar model families such as CodeLlama, Vicuna, Yi, and more.
@@ -33,22 +33,12 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
- [Acknowledgements](#acknowledgements)
## Supported Models (Quick Start)
+ * Onnxruntime DirectML Models [Link](./docs/model/onnxruntime_directml_models.md)
+ * Onnxruntime CPU Models [Link](./docs/model/onnxruntime_cpu_models.md)
+ * Ipex-LLM Models [Link](./docs/model/ipex_models.md)
+ * OpenVINO-LLM Models [Link](./docs/model/openvino_models.md)
+ * NPU-LLM Models [Link](./docs/model/npu_models.md)
-| Models | Parameters | Context Length | Link |
-| --- | --- | --- | --- |
-| Gemma-2b-Instruct v1 | 2B | 8192 | [EmbeddedLLM/gemma-2b-it-onnx](https://huggingface.co/EmbeddedLLM/gemma-2b-it-onnx) |
-| Llama-2-7b-chat | 7B | 4096 | [EmbeddedLLM/llama-2-7b-chat-int4-onnx-directml](https://huggingface.co/EmbeddedLLM/llama-2-7b-chat-int4-onnx-directml) |
-| Llama-2-13b-chat | 13B | 4096 | [EmbeddedLLM/llama-2-13b-chat-int4-onnx-directml](https://huggingface.co/EmbeddedLLM/llama-2-13b-chat-int4-onnx-directml) |
-| Llama-3-8b-chat | 8B | 8192 | [EmbeddedLLM/mistral-7b-instruct-v0.3-onnx](https://huggingface.co/EmbeddedLLM/mistral-7b-instruct-v0.3-onnx) |
-| Mistral-7b-v0.3-instruct | 7B | 32768 | [EmbeddedLLM/mistral-7b-instruct-v0.3-onnx](https://huggingface.co/EmbeddedLLM/mistral-7b-instruct-v0.3-onnx) |
-| Phi-3-mini-4k-instruct-062024 | 3.8B | 4096 | [EmbeddedLLM/Phi-3-mini-4k-instruct-062024-onnx](https://huggingface.co/EmbeddedLLM/Phi-3-mini-4k-instruct-062024-onnx/tree/main/onnx/directml/Phi-3-mini-4k-instruct-062024-int4) |
-| Phi3-mini-4k-instruct | 3.8B | 4096 | [microsoft/Phi-3-mini-4k-instruct-onnx](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx) |
-| Phi3-mini-128k-instruct | 3.8B | 128k | [microsoft/Phi-3-mini-128k-instruct-onnx](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct-onnx) |
-| Phi3-medium-4k-instruct | 17B | 4096 | [microsoft/Phi-3-medium-4k-instruct-onnx-directml](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct-onnx-directml) |
-| Phi3-medium-128k-instruct | 17B | 128k | [microsoft/Phi-3-medium-128k-instruct-onnx-directml](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct-onnx-directml) |
-| Openchat-3.6-8b | 8B | 8192 | [EmbeddedLLM/openchat-3.6-8b-20240522-onnx](https://huggingface.co/EmbeddedLLM/openchat-3.6-8b-20240522-onnx) |
-| Yi-1.5-6b-chat | 6B | 32k | [EmbeddedLLM/01-ai_Yi-1.5-6B-Chat-onnx](https://huggingface.co/EmbeddedLLM/01-ai_Yi-1.5-6B-Chat-onnx) |
-| Phi-3-vision-128k-instruct | | 128k | [EmbeddedLLM/Phi-3-vision-128k-instruct-onnx](https://huggingface.co/EmbeddedLLM/Phi-3-vision-128k-instruct-onnx/tree/main/onnx/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) |
## Getting Started
@@ -70,12 +60,14 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
- **CUDA:** `$env:ELLM_TARGET_DEVICE='cuda'; pip install -e .[cuda]`
- **IPEX:** `$env:ELLM_TARGET_DEVICE='ipex'; python setup.py develop`
- **OpenVINO:** `$env:ELLM_TARGET_DEVICE='openvino'; pip install -e .[openvino]`
+ - **NPU:** `$env:ELLM_TARGET_DEVICE='npu'; pip install -e .[npu]`
- **With Web UI**:
- **DirectML:** `$env:ELLM_TARGET_DEVICE='directml'; pip install -e .[directml,webui]`
- **CPU:** `$env:ELLM_TARGET_DEVICE='cpu'; pip install -e .[cpu,webui]`
- **CUDA:** `$env:ELLM_TARGET_DEVICE='cuda'; pip install -e .[cuda,webui]`
- **IPEX:** `$env:ELLM_TARGET_DEVICE='ipex'; python setup.py develop; pip install -r requirements-webui.txt`
- **OpenVINO:** `$env:ELLM_TARGET_DEVICE='openvino'; pip install -e .[openvino,webui]`
+ - **NPU:** `$env:ELLM_TARGET_DEVICE='npu'; pip install -e .[npu,webui]`
- **Linux**
@@ -91,12 +83,14 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
- **CUDA:** `ELLM_TARGET_DEVICE='cuda' pip install -e .[cuda]`
- **IPEX:** `ELLM_TARGET_DEVICE='ipex' python setup.py develop`
- **OpenVINO:** `ELLM_TARGET_DEVICE='openvino' pip install -e .[openvino]`
+ - **NPU:** `ELLM_TARGET_DEVICE='npu' pip install -e .[npu]`
- **With Web UI**:
- **DirectML:** `ELLM_TARGET_DEVICE='directml' pip install -e .[directml,webui]`
- **CPU:** `ELLM_TARGET_DEVICE='cpu' pip install -e .[cpu,webui]`
- **CUDA:** `ELLM_TARGET_DEVICE='cuda' pip install -e .[cuda,webui]`
- **IPEX:** `ELLM_TARGET_DEVICE='ipex' python setup.py develop; pip install -r requirements-webui.txt`
- **OpenVINO:** `ELLM_TARGET_DEVICE='openvino' pip install -e .[openvino,webui]`
+ - **NPU:** `ELLM_TARGET_DEVICE='npu' pip install -e .[npu,webui]`
### Launch OpenAI API Compatible Server
@@ -121,7 +115,7 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E
### Launch Chatbot Web UI
-1. `ellm_chatbot --port 7788 --host localhost --server_port --server_host localhost`. **Note:** To find out more of the supported arguments. `ellm_chatbot --help`.
+1. `ellm_chatbot --port 7788 --host localhost --server_port --server_host localhost --model_name `. **Note:** Run `ellm_chatbot --help` to see all supported arguments.
![asset/ellm_chatbot_vid.webp](asset/ellm_chatbot_vid.webp)
@@ -156,6 +150,9 @@ It is an interface that allows you to download and deploy OpenAI API compatible
# OpenVINO
ellm_server --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\' --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct'
+
+ # NPU
+ ellm_server --model_path 'microsoft/Phi-3-mini-4k-instruct' --backend 'npu' --device 'npu' --port 5555 --served_model_name 'microsoft/Phi-3-mini-4k-instruct'
```
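+
+Once a server is running, it can be queried with any OpenAI-compatible client. Below is a minimal sketch in Python (assuming the `openai` package v1+ is installed, the server above is listening on port 5555, and the placeholder API key is accepted):
+
+```python
+from openai import OpenAI
+
+# Point the client at the local EmbeddedLLM server instead of api.openai.com.
+client = OpenAI(base_url="http://localhost:5555/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+    model="microsoft/Phi-3-mini-4k-instruct",  # must match --served_model_name
+    messages=[{"role": "user", "content": "Hello!"}],
+    max_tokens=64,
+)
+print(response.choices[0].message.content)
+```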
## Prebuilt OpenAI API Compatible Windows Executable (Alpha)
@@ -168,13 +165,16 @@ _Powershell/Terminal Usage (Use it like `ellm_server`)_:
.\ellm_api_server.exe --model_path
# DirectML
-.\ellm_api_server.exe --model_path 'EmbeddedLLM_Phi-3-mini-4k-instruct-062024-onnx\onnx\directml\Phi-3-mini-4k-instruct-062024-int4' --port 5555
+.\ellm_api_server.exe --model_path 'EmbeddedLLM/Phi-3-mini-4k-instruct-onnx-directml' --port 5555
# IPEX-LLM
.\ellm_api_server.exe --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\' --backend 'ipex' --device 'xpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct'
# OpenVINO
.\ellm_api_server.exe --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\' --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct'
+
+# NPU
+.\ellm_api_server.exe --model_path 'microsoft/Phi-3-mini-4k-instruct' --backend 'npu' --device 'npu' --port 5555 --served_model_name 'microsoft/Phi-3-mini-4k-instruct'
```
## Acknowledgements
diff --git a/docs/model/npu_models.md b/docs/model/npu_models.md
new file mode 100644
index 0000000..c1d2b06
--- /dev/null
+++ b/docs/model/npu_models.md
@@ -0,0 +1,15 @@
+# Model Powered by NPU-LLM
+
+## Verified Models
+Verified models can be found in the EmbeddedLLM NPU-LLM model collection.
+* EmbeddedLLM NPU-LLM Model collections: [link](https://huggingface.co/collections/EmbeddedLLM/npu-llm-66d692817e6c9509bb8ead58)
+
+| Model | Model Link |
+| --- | --- |
+| Phi-3-mini-4k-instruct | [link](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) |
+| Phi-3-mini-128k-instruct | [link](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) |
+| Phi-3-medium-4k-instruct | [link](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct) |
+| Phi-3-medium-128k-instruct | [link](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct) |
+
+## Contribution
+We welcome contributions to the verified model list.
\ No newline at end of file
diff --git a/requirements-npu.txt b/requirements-npu.txt
new file mode 100644
index 0000000..dbcb8cf
--- /dev/null
+++ b/requirements-npu.txt
@@ -0,0 +1,3 @@
+intel-npu-acceleration-library
+torch>=2.4
+transformers>=4.42
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 4520ee6..50ce2f9 100644
--- a/setup.py
+++ b/setup.py
@@ -54,6 +54,10 @@ def _is_openvino() -> bool:
return ELLM_TARGET_DEVICE == "openvino"
+def _is_npu() -> bool:
+ return ELLM_TARGET_DEVICE == "npu"
+
+
class ELLMInstallCommand(install):
def run(self):
install.run(self)
@@ -198,6 +202,8 @@ def get_requirements() -> List[str]:
requirements = _read_requirements("requirements-ipex.txt")
elif _is_openvino():
requirements = _read_requirements("requirements-openvino.txt")
+ elif _is_npu():
+ requirements = _read_requirements("requirements-npu.txt")
else:
raise ValueError("Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
return requirements
@@ -216,6 +222,8 @@ def get_ellm_version() -> str:
version += "+ipex"
elif _is_openvino():
version += "+openvino"
+ elif _is_npu():
+ version += "+npu"
else:
raise RuntimeError("Unknown runtime environment")
@@ -268,6 +276,7 @@ def get_ellm_version() -> str:
"cuda": ["onnxruntime-genai-cuda==0.3.0rc2"],
"ipex": [],
"openvino": [],
+ "npu": [],
},
dependency_links=dependency_links,
entry_points={
diff --git a/src/embeddedllm/backend/intel_npu_engine.py b/src/embeddedllm/backend/intel_npu_engine.py
new file mode 100644
index 0000000..c245e43
--- /dev/null
+++ b/src/embeddedllm/backend/intel_npu_engine.py
@@ -0,0 +1,268 @@
+import contextlib
+import time
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import AsyncIterator, List, Optional
+
+from loguru import logger
+from PIL import Image
+from transformers import (
+ AutoConfig,
+ PreTrainedTokenizer,
+ PreTrainedTokenizerFast,
+ TextIteratorStreamer,
+)
+
+from threading import Thread
+
+import intel_npu_acceleration_library as npu_lib
+
+from embeddedllm.inputs import PromptInputs
+from embeddedllm.protocol import CompletionOutput, RequestOutput
+from embeddedllm.sampling_params import SamplingParams
+from embeddedllm.backend.base_engine import BaseLLMEngine, _get_and_verify_max_len
+
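+# When True, log time-to-first-token and tokens-per-second statistics for each streamed request.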
+RECORD_TIMING = True
+
+
+class NPUEngine(BaseLLMEngine):
+ def __init__(self, model_path: str, vision: bool, device: str = "npu"):
+ self.model_path = model_path
+ self.model_config: AutoConfig = AutoConfig.from_pretrained(
+ self.model_path, trust_remote_code=True
+ )
+ self.device = device
+
+ # model_config is to find out the max length of the model
+ self.max_model_len = _get_and_verify_max_len(
+ hf_config=self.model_config,
+ max_model_len=None,
+ disable_sliding_window=False,
+ sliding_window_len=self.get_hf_config_sliding_window(),
+ )
+
+ logger.info("Model Context Length: " + str(self.max_model_len))
+
+ try:
+ logger.info("Attempt to load fast tokenizer")
+ self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path)
+ except Exception:
+ logger.info("Attempt to load slower tokenizer")
+ self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)
+
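+        # Load the causal LM through intel-npu-acceleration-library, requesting int4 weight quantization for the NPU.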
+ self.model = npu_lib.NPUModelForCausalLM.from_pretrained(
+ self.model_path,
+ torch_dtype="auto",
+ dtype=npu_lib.int4,
+ trust_remote_code=True,
+ export=False
+ )
+
+ logger.info("Model loaded")
+ self.tokenizer_stream = TextIteratorStreamer(
+ self.tokenizer, skip_prompt=True, skip_special_tokens=True
+ )
+ logger.info("Tokenizer created")
+
+ self.vision = vision
+
+ # if self.vision:
+ # self.onnx_processor = self.model.create_multimodal_processor()
+ # self.processor = AutoImageProcessor.from_pretrained(
+ # self.model_path, trust_remote_code=True
+ # )
+ # print(dir(self.processor))
+
+ async def generate_vision(
+ self,
+ inputs: PromptInputs,
+ sampling_params: SamplingParams,
+ request_id: str,
+ stream: bool = True,
+ ) -> AsyncIterator[RequestOutput]:
+        raise NotImplementedError("generate_vision is not implemented yet.")
+
+ async def generate(
+ self,
+ inputs: PromptInputs,
+ sampling_params: SamplingParams,
+ request_id: str,
+ stream: bool = True,
+ ) -> AsyncIterator[RequestOutput]:
+ """Generate outputs for a request.
+
+ Generate outputs for a request. This method is a coroutine. It adds the
+ request into the waiting queue of the LLMEngine and streams the outputs
+ from the LLMEngine to the caller.
+
+ """
+
+ prompt_text = inputs["prompt"]
+ input_token_length = None
+ input_tokens = None # for text only use case
+ # logger.debug("inputs: " + prompt_text)
+
+ input_tokens = self.tokenizer.encode(prompt_text, return_tensors="pt")
+ # logger.debug(f"input_tokens: {input_tokens}")
+ input_token_length = len(input_tokens[0])
+
+ max_tokens = sampling_params.max_tokens
+
+ assert input_token_length is not None
+
+ if input_token_length + max_tokens > self.max_model_len:
+ raise ValueError("Exceed Context Length")
+
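+        # Build HF generate() kwargs from whichever sampling parameters are present on sampling_params.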
+ generation_options = {
+ name: getattr(sampling_params, name)
+ for name in [
+ "do_sample",
+ # "max_length",
+ "max_new_tokens",
+ "min_length",
+ "top_p",
+ "top_k",
+ "temperature",
+ "repetition_penalty",
+ ]
+ if hasattr(sampling_params, name)
+ }
+ generation_options["max_length"] = self.max_model_len
+ generation_options["input_ids"] = input_tokens.clone()
+ # generation_options["input_ids"] = input_tokens.clone().to(self.device)
+ generation_options["max_new_tokens"] = max_tokens
+        logger.debug(f"generation_options: {generation_options}")
+
+ token_list: List[int] = []
+ output_text: str = ""
+ if stream:
+ generation_options["streamer"] = self.tokenizer_stream
+            if RECORD_TIMING:
+                started_timestamp = time.time()
+                first_token_timestamp = 0
+                first = True
+                new_tokens = []
+            try:
+                # Generation runs on a worker thread; tokens are consumed from the streamer below.
+                thread = Thread(target=self.model.generate, kwargs=generation_options)
+                thread.start()
+ for new_text in self.tokenizer_stream:
+ if new_text == "":
+ continue
+ if RECORD_TIMING:
+ if first:
+ first_token_timestamp = time.time()
+ first = False
+ # logger.debug(f"new text: {new_text}")
+ output_text += new_text
+ token_list = self.tokenizer.encode(output_text, return_tensors="pt")
+
+ output = RequestOutput(
+ request_id=request_id,
+ prompt=prompt_text,
+ prompt_token_ids=input_tokens[0],
+ finished=False,
+ outputs=[
+ CompletionOutput(
+ index=0,
+ text=output_text,
+ token_ids=token_list[0],
+ cumulative_logprob=-1.0,
+ )
+ ],
+ )
+ yield output
+ # logits = generator.get_output("logits")
+ # print(logits)
+ if RECORD_TIMING:
+ new_tokens = token_list[0]
+
+ yield RequestOutput(
+ request_id=request_id,
+ prompt=prompt_text,
+ prompt_token_ids=input_tokens[0],
+ finished=True,
+ outputs=[
+ CompletionOutput(
+ index=0,
+ text=output_text,
+ token_ids=token_list[0],
+ cumulative_logprob=-1.0,
+ finish_reason="stop",
+ )
+ ],
+ )
+ if RECORD_TIMING:
+ prompt_time = first_token_timestamp - started_timestamp
+ run_time = time.time() - first_token_timestamp
+ logger.info(
+ f"Prompt length: {len(input_tokens[0])}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens[0])/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
+ )
+
+ except Exception as e:
+ logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=prompt_text,
+                    prompt_token_ids=input_tokens[0],
+ finished=True,
+ request_id=request_id,
+ outputs=[
+ CompletionOutput(
+ index=0,
+ text=output_text,
+ token_ids=token_list,
+ cumulative_logprob=-1.0,
+ finish_reason="error",
+ stop_reason=str(e),
+ )
+ ],
+ )
+ yield error_output
+ else:
+ try:
+ token_list = self.model.generate(**generation_options)[0]
+
+ output_text = self.tokenizer.decode(
+ token_list[input_token_length:], skip_special_tokens=True
+ )
+
+ yield RequestOutput(
+ request_id=request_id,
+ prompt=prompt_text,
+ prompt_token_ids=input_tokens[0],
+ finished=True,
+ outputs=[
+ CompletionOutput(
+ index=0,
+ text=output_text,
+ token_ids=token_list,
+ cumulative_logprob=-1.0,
+ finish_reason="stop",
+ )
+ ],
+ )
+
+ except Exception as e:
+ logger.error(str(e))
+
+ error_output = RequestOutput(
+ prompt=prompt_text,
+ prompt_token_ids=input_tokens[0],
+ finished=True,
+ request_id=request_id,
+ outputs=[
+ CompletionOutput(
+ index=0,
+ text=output_text,
+ token_ids=token_list,
+ cumulative_logprob=-1.0,
+ finish_reason="error",
+ stop_reason=str(e),
+ )
+ ],
+ )
+ yield error_output
\ No newline at end of file
diff --git a/src/embeddedllm/engine.py b/src/embeddedllm/engine.py
index e2c5a9d..b341472 100644
--- a/src/embeddedllm/engine.py
+++ b/src/embeddedllm/engine.py
@@ -56,6 +56,22 @@ def __init__(self, model_path: str, vision: bool, device: str = "xpu", backend:
self.engine = OnnxruntimeEngine(self.model_path, self.vision, self.device)
logger.info(f"Initializing onnxruntime backend ({backend.upper()}): OnnxruntimeEngine")
+
+        elif self.backend == "npu":
+            assert self.device == "npu", "To run `npu` backend, `device` must be `npu`."
+            processor = get_processor_type()
+            if processor == "Intel":
+                from embeddedllm.backend.intel_npu_engine import NPUEngine
+
+                self.engine = NPUEngine(self.model_path, self.vision, self.device)
+                logger.info("Initializing Intel NPU backend: NPUEngine")
+
+            elif processor == "AMD":
+                raise SystemError("NPU backend is not supported on AMD platforms yet.")
+
+            else:
+                raise SystemError("Unknown processor type; the `npu` backend currently supports Intel NPUs only.")
+
elif self.backend == "cpu":
assert self.device == "cpu", f"To run `cpu` backend, `device` must be `cpu`."
processor = get_processor_type()
@@ -80,7 +96,7 @@ def __init__(self, model_path: str, vision: bool, device: str = "xpu", backend:
else:
raise ValueError(
- f"EmbeddedLLMEngine only supports `cpu`, `ipex`, `cuda`, `openvino` and `directml`."
+ f"EmbeddedLLMEngine only supports `cpu`, `npu`, `ipex`, `cuda`, `openvino` and `directml`."
)
self.tokenizer = self.engine.tokenizer
diff --git a/src/embeddedllm/entrypoints/modelui.py b/src/embeddedllm/entrypoints/modelui.py
index 9c82355..81cb681 100644
--- a/src/embeddedllm/entrypoints/modelui.py
+++ b/src/embeddedllm/entrypoints/modelui.py
@@ -20,7 +20,7 @@ def get_embeddedllm_backend():
version = importlib.metadata.version("embeddedllm")
# Use regex to extract the backend
- match = re.search(r"\+(directml|cpu|cuda|ipex|openvino)$", version)
+ match = re.search(r"\+(directml|npu|cpu|cuda|ipex|openvino)$", version)
if match:
backend = match.group(1)
@@ -260,6 +260,41 @@ class ModelCard(BaseModel):
),
}
+npu_model_dict_list = {
+ "microsoft/Phi-3-mini-4k-instruct": ModelCard(
+ hf_url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main/",
+ repo_id="microsoft/Phi-3-mini-4k-instruct",
+ model_name="Phi-3-mini-4k-instruct",
+ subfolder=".",
+ repo_type="model",
+ context_length=4096,
+ ),
+ "microsoft/Phi-3-mini-128k-instruct": ModelCard(
+ hf_url="https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/tree/main",
+ repo_id="microsoft/Phi-3-mini-128k-instruct",
+ model_name="Phi-3-mini-128k-instruct",
+ subfolder=".",
+ repo_type="model",
+ context_length=131072,
+ ),
+ "microsoft/Phi-3-medium-4k-instruct": ModelCard(
+ hf_url="https://huggingface.co/microsoft/Phi-3-medium-4k-instruct/tree/main",
+ repo_id="microsoft/Phi-3-medium-4k-instruct",
+ model_name="Phi-3-medium-4k-instruct",
+ subfolder=".",
+ repo_type="model",
+ context_length=4096,
+ ),
+ "microsoft/Phi-3-medium-128k-instruct": ModelCard(
+ hf_url="https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/tree/main",
+ repo_id="microsoft/Phi-3-medium-128k-instruct",
+ model_name="Phi-3-medium-128k-instruct",
+ subfolder=".",
+ repo_type="model",
+ context_length=131072,
+ ),
+}
+
ipex_model_dict_list = {
"microsoft/Phi-3-mini-4k-instruct": ModelCard(
hf_url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main/",
@@ -507,6 +542,11 @@ def compute_memory_size(repo_id, path_in_repo, repo_type: str = "model"):
repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type
)
+for k, v in npu_model_dict_list.items():
+ v.size = compute_memory_size(
+ repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type
+ )
+
for k, v in ipex_model_dict_list.items():
v.size = compute_memory_size(
repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type
@@ -603,6 +643,9 @@ def update_model_list(engine_type):
if engine_type == "DirectML":
models = sorted(list(dml_model_dict_list.keys()))
models_pandas = convert_to_dataframe(dml_model_dict_list)
+ elif backend == "npu":
+ models = sorted(list(npu_model_dict_list.keys()))
+ models_pandas = convert_to_dataframe(npu_model_dict_list)
elif backend == "ipex":
models = sorted(list(ipex_model_dict_list.keys()))
models_pandas = convert_to_dataframe(ipex_model_dict_list)
@@ -631,6 +674,8 @@ def deploy_model(engine_type, model_name, port_number):
if engine_type == "DirectML":
llm_model_card = dml_model_dict_list[model_name]
+ elif backend == "npu":
+ llm_model_card = npu_model_dict_list[model_name]
elif backend == "ipex":
llm_model_card = ipex_model_dict_list[model_name]
elif backend == "openvino":
@@ -654,7 +699,9 @@ def deploy_model(engine_type, model_name, port_number):
model_path = llm_model_card.repo_id
print("Model path:", model_path)
- if engine_type == "Ipex":
+ if engine_type == "NPU":
+ device = "npu"
+ elif engine_type == "Ipex":
device = "xpu"
elif engine_type == "OpenVino":
device = "gpu"
@@ -718,6 +765,8 @@ def download_model(engine_type, model_name):
if engine_type == "DirectML":
llm_model_card = dml_model_dict_list[model_name]
+ elif backend == "npu":
+ llm_model_card = npu_model_dict_list[model_name]
elif backend == "ipex":
llm_model_card = ipex_model_dict_list[model_name]
elif backend == "openvino":
@@ -771,6 +820,8 @@ def main():
if backend == "directml":
default_value = "DirectML"
+ elif backend == "npu":
+ default_value = "NPU"
elif backend == "ipex":
default_value = "Ipex"
elif backend == "openvino":