Add NPU Engine #31 (Open)

szeyu wants to merge 11 commits into main from szeyu-npu-1
Changes from 5 commits

Commits (11)
014a411 update npu models and engine setup (szeyu)
3e92ce5 Update README.md (szeyu)
736ea85 Update README.md (szeyu)
0bacaa7 fix the typo of __init__ (szeyu)
2d730a3 Update modelui.py (szeyu)
c5aee57 Merge branch 'szeyu-patch-2' into szeyu-npu-1 (szeyu)
abfab05 update gitignore (szeyu)
e77cd9d Merge branch 'main' into szeyu-npu-1 (szeyu)
d7586d4 Renamed to npu_engine to intel_npu_engine to specify that it is intel… (szeyu)
e0d320f Add support for Intel NPU backend and handle unsupported processors (szeyu)
077562d Merge branch 'main' into szeyu-npu-1 (szeyu)
@@ -0,0 +1,15 @@
# Model Powered by NPU-LLM

## Verified Models
Verified models can be found in the EmbeddedLLM NPU-LLM model collection.
* EmbeddedLLM NPU-LLM model collection: [link](https://huggingface.co/collections/EmbeddedLLM/npu-llm-66d692817e6c9509bb8ead58)

| Model | Model Link |
| --- | --- |
| Phi-3-mini-4k-instruct | [link](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) |
| Phi-3-mini-128k-instruct | [link](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) |
| Phi-3-medium-4k-instruct | [link](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct) |
| Phi-3-medium-128k-instruct | [link](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct) |

## Contribution
We welcome contributions to the verified model list.
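
For context, a minimal sketch of how one of the verified models above might be loaded through intel_npu_acceleration_library; the loading arguments mirror the NPUEngine code added later in this PR, while the tokenizer and generate calls are illustrative assumptions rather than part of the diff.

```python
# Sketch: load a verified model on the Intel NPU, mirroring the arguments
# the new NPUEngine passes to NPUModelForCausalLM.from_pretrained below.
# Whether export=False is right for a fresh Hugging Face download may
# depend on the library version.
import intel_npu_acceleration_library as npu_lib
from transformers import AutoTokenizer

model_id = "microsoft/Phi-3-mini-4k-instruct"  # any entry from the table above

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = npu_lib.NPUModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    dtype=npu_lib.int4,  # 4-bit weights for the NPU
    trust_remote_code=True,
    export=False,
)

prompt = tokenizer("What is an NPU?", return_tensors="pt")
output_ids = model.generate(**prompt, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```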
@@ -0,0 +1,3 @@
intel-npu-acceleration-library
torch>=2.4
transformers>=4.42
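
As a sanity check, a small sketch (not from the PR) that verifies these pinned dependencies are importable; the distribution name passed to importlib.metadata is assumed to match the requirement line above.

```python
# Sketch: confirm the NPU backend requirements are installed and importable.
from importlib.metadata import version

import torch
import transformers
import intel_npu_acceleration_library  # noqa: F401 -- presence check only

print("torch:", torch.__version__)                # expected >= 2.4
print("transformers:", transformers.__version__)  # expected >= 4.42
print("intel-npu-acceleration-library:", version("intel-npu-acceleration-library"))
```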
@@ -0,0 +1,268 @@
import contextlib
import time
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import AsyncIterator, List, Optional

from loguru import logger
from PIL import Image
from transformers import (
    AutoConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TextIteratorStreamer,
)

from threading import Thread

import intel_npu_acceleration_library as npu_lib

from embeddedllm.inputs import PromptInputs
from embeddedllm.protocol import CompletionOutput, RequestOutput
from embeddedllm.sampling_params import SamplingParams
from embeddedllm.backend.base_engine import BaseLLMEngine, _get_and_verify_max_len

RECORD_TIMING = True


class NPUEngine(BaseLLMEngine):
    def __init__(self, model_path: str, vision: bool, device: str = "npu"):
        self.model_path = model_path
        self.model_config: AutoConfig = AutoConfig.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.device = device

        # model_config is to find out the max length of the model
        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.model_config,
            max_model_len=None,
            disable_sliding_window=False,
            sliding_window_len=self.get_hf_config_sliding_window(),
        )

        logger.info("Model Context Length: " + str(self.max_model_len))

        try:
            logger.info("Attempt to load fast tokenizer")
            self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path)
        except Exception:
            logger.info("Attempt to load slower tokenizer")
            self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)

        self.model = npu_lib.NPUModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype="auto",
            dtype=npu_lib.int4,
            trust_remote_code=True,
            export=False,
        )

        logger.info("Model loaded")
        self.tokenizer_stream = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        logger.info("Tokenizer created")

        self.vision = vision

        # if self.vision:
        #     self.onnx_processor = self.model.create_multimodal_processor()
        #     self.processor = AutoImageProcessor.from_pretrained(
        #         self.model_path, trust_remote_code=True
        #     )
        #     print(dir(self.processor))

    async def generate_vision(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        stream: bool = True,
    ) -> AsyncIterator[RequestOutput]:
        raise NotImplementedError("generate_vision yet to be implemented.")

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        stream: bool = True,
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.
        """

        prompt_text = inputs["prompt"]
        input_token_length = None
        input_tokens = None  # for text only use case
        # logger.debug("inputs: " + prompt_text)

        input_tokens = self.tokenizer.encode(prompt_text, return_tensors="pt")
        # logger.debug(f"input_tokens: {input_tokens}")
        input_token_length = len(input_tokens[0])

        max_tokens = sampling_params.max_tokens

        assert input_token_length is not None

        if input_token_length + max_tokens > self.max_model_len:
            raise ValueError("Exceed Context Length")

        generation_options = {
            name: getattr(sampling_params, name)
            for name in [
                "do_sample",
                # "max_length",
                "max_new_tokens",
                "min_length",
                "top_p",
                "top_k",
                "temperature",
                "repetition_penalty",
            ]
            if hasattr(sampling_params, name)
        }
        generation_options["max_length"] = self.max_model_len
        generation_options["input_ids"] = input_tokens.clone()
        # generation_options["input_ids"] = input_tokens.clone().to(self.device)
        generation_options["max_new_tokens"] = max_tokens
        print(generation_options)

        token_list: List[int] = []
        output_text: str = ""
        if stream:
            generation_options["streamer"] = self.tokenizer_stream
            if RECORD_TIMING:
                started_timestamp = time.time()
                first_token_timestamp = 0
                first = True
                new_tokens = []
            try:
                thread = Thread(target=self.model.generate, kwargs=generation_options)
                started_timestamp = time.time()
                first_token_timestamp = None
                thread.start()
                output_text = ""
                first = True
                for new_text in self.tokenizer_stream:
                    if new_text == "":
                        continue
                    if RECORD_TIMING:
                        if first:
                            first_token_timestamp = time.time()
                            first = False
                    # logger.debug(f"new text: {new_text}")
                    output_text += new_text
                    token_list = self.tokenizer.encode(output_text, return_tensors="pt")

                    output = RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        prompt_token_ids=input_tokens[0],
                        finished=False,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text=output_text,
                                token_ids=token_list[0],
                                cumulative_logprob=-1.0,
                            )
                        ],
                    )
                    yield output
                    # logits = generator.get_output("logits")
                    # print(logits)
                    if RECORD_TIMING:
                        new_tokens = token_list[0]

                yield RequestOutput(
                    request_id=request_id,
                    prompt=prompt_text,
                    prompt_token_ids=input_tokens[0],
                    finished=True,
                    outputs=[
                        CompletionOutput(
                            index=0,
                            text=output_text,
                            token_ids=token_list[0],
                            cumulative_logprob=-1.0,
                            finish_reason="stop",
                        )
                    ],
                )
                if RECORD_TIMING:
                    prompt_time = first_token_timestamp - started_timestamp
                    run_time = time.time() - first_token_timestamp
                    logger.info(
                        f"Prompt length: {len(input_tokens[0])}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens[0])/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
                    )

            except Exception as e:
                logger.error(str(e))

                error_output = RequestOutput(
                    prompt=prompt_text,
                    prompt_token_ids=input_tokens[0],
                    finished=True,
                    request_id=request_id,
                    outputs=[
                        CompletionOutput(
                            index=0,
                            text=output_text,
                            token_ids=token_list,
                            cumulative_logprob=-1.0,
                            finish_reason="error",
                            stop_reason=str(e),
                        )
                    ],
                )
                yield error_output
        else:
            try:
                token_list = self.model.generate(**generation_options)[0]

                output_text = self.tokenizer.decode(
                    token_list[input_token_length:], skip_special_tokens=True
                )

                yield RequestOutput(
                    request_id=request_id,
                    prompt=prompt_text,
                    prompt_token_ids=input_tokens[0],
                    finished=True,
                    outputs=[
                        CompletionOutput(
                            index=0,
                            text=output_text,
                            token_ids=token_list,
                            cumulative_logprob=-1.0,
                            finish_reason="stop",
                        )
                    ],
                )

            except Exception as e:
                logger.error(str(e))

                error_output = RequestOutput(
                    prompt=prompt_text,
                    prompt_token_ids=input_tokens[0],
                    finished=True,
                    request_id=request_id,
                    outputs=[
                        CompletionOutput(
                            index=0,
                            text=output_text,
                            token_ids=token_list,
                            cumulative_logprob=-1.0,
                            finish_reason="error",
                            stop_reason=str(e),
                        )
                    ],
                )
                yield error_output
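
For reviewers trying the branch locally, a hedged sketch of how the new engine's streaming generate() coroutine might be driven; the module path and the SamplingParams keyword are assumptions inferred from how generate() reads them, not a confirmed public API.

```python
# Sketch: consume NPUEngine.generate() as an async stream.
# Assumptions: the engine module lives at embeddedllm.backend.npu_engine
# (subject to the rename discussed below), PromptInputs is dict-like with a
# "prompt" key, and SamplingParams accepts max_tokens as a keyword.
import asyncio

from embeddedllm.backend.npu_engine import NPUEngine
from embeddedllm.sampling_params import SamplingParams


async def main() -> None:
    engine = NPUEngine(model_path="path/to/local/npu/model", vision=False)
    params = SamplingParams(max_tokens=64)

    async for request_output in engine.generate(
        inputs={"prompt": "Explain what an NPU is in one sentence."},
        sampling_params=params,
        request_id="demo-1",
        stream=True,
    ):
        # Each RequestOutput carries the full text generated so far.
        print(request_output.outputs[0].text)


asyncio.run(main())
```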
Can you rename npu_engine.py to intel_npu_engine.py, since this is NPU code for Intel only? Do DM me on WhatsApp to discuss this if you think otherwise.