[Question] What is the value range for inputs for batch_generator and more #1523

Closed
rakesh-krishna opened this issue Oct 26, 2023 · 2 comments

@rakesh-krishna

Hi, I am new to CTranslate2.
I am writing a FastAPI server for LLM inference with the code below as a base.
I want to write an input Pydantic model to validate the inputs (so far the input handling is static, but I want to change that).
I went through the docs and tried to read the code, but I am not familiar enough with C++ to figure it out.
So if anyone has an idea of the valid range for the inputs, please let me know.

Thanks in advance 😄

from typing import List
import ctranslate2
import transformers
from fastapi import FastAPI
from pydantic import BaseModel
import asyncio
import uvicorn

app = FastAPI()

generator = ctranslate2.Generator("gpt2_ct2",device="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")


# Define the request and response models for the FastAPI endpoint.
class TranslationRequest(BaseModel):
    text: str


class TranslationResponse(BaseModel):
    translations: List[str]


# Initialize a queue for batching translation requests and a dictionary for results.
request_queue = asyncio.Queue()
results = {}
MAX_BATCH_SIZE = 100
TIMEOUT = 0.1  # Time to wait for accumulating enough batch items.
batch_ready_event = asyncio.Event()
results_lock = asyncio.Lock()  # Synchronizes access to the results dictionary.


async def batch_processor():
    """Asynchronously process translation requests in batches."""
    while True:
        try:
            await asyncio.wait_for(batch_ready_event.wait(), timeout=TIMEOUT)
        except asyncio.TimeoutError:
            pass

        # Accumulate translation requests for batching.
        batched_items = []
        identifiers = []
        while not request_queue.empty() and len(batched_items) < MAX_BATCH_SIZE:
            uid, text = await request_queue.get()
            batched_items.append(text)
            identifiers.append(uid)

        # If there are items to translate, process them.
        if batched_items:
            print(f"Translating a batch of {len(batched_items)} items.")
            try:
                # Run the blocking generation call in a worker thread so the
                # event loop stays responsive while the model is busy.
                translations = await asyncio.to_thread(translate_batch, batched_items)
                async with results_lock:
                    for uid, translation in zip(identifiers, translations):
                        if uid in results:
                            results[uid]["translation"] = translation
                            event = results[uid]["event"]
                            event.set()
            except Exception as e:
                # Handle translation errors.
                print(f"Error during translation: {e}")

        # Reset the event to wait for the next batch.
        batch_ready_event.clear()


@app.on_event("startup")
async def startup_event():
    """On server startup, initialize the batch processor."""
    asyncio.create_task(batch_processor())


@app.post("/generate", response_model=TranslationResponse)
async def translate_endpoint(request: TranslationRequest):
    """Endpoint to handle translation requests."""
    result_event = asyncio.Event()
    unique_id = str(id(result_event))
    async with results_lock:
        results[unique_id] = {"event": result_event, "translation": None}
    await request_queue.put((unique_id, request.text))

    # Trigger the batch_ready_event if the queue size reaches the defined threshold.
    if request_queue.qsize() >= MAX_BATCH_SIZE:
        batch_ready_event.set()

    # Wait for the translation result.
    await result_event.wait()
    async with results_lock:
        translation = results.pop(unique_id, {}).get("translation", "")
    # The key must match the TranslationResponse field name.
    return {"translations": [translation]}


def translate_batch(texts: List[str]) -> List[str]:
    """Translate a batch of texts using the pre-initialized translator and tokenizer."""
    # Tokenize source texts.
    sources = [tokenizer.convert_ids_to_tokens(tokenizer.encode(text)) for text in texts]
    batch_results = generator.generate_batch(sources, max_length=30, sampling_topk=10)
    translations = []
    for result in batch_results:
        # Skip the first token of the returned sequence and decode the rest.
        target = result.sequences[0][1:]
        translations.append(tokenizer.decode(tokenizer.convert_tokens_to_ids(target)))
    return translations


if __name__ == '__main__':
    # Start the FastAPI server.
    uvicorn.run(app=app, host='0.0.0.0', port=8000)
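
For reference, a request to this endpoint could look like the following (a minimal sketch, assuming the server is running locally on port 8000):

import requests

response = requests.post(
    "http://localhost:8000/generate",
    json={"text": "Hello, my name is"},
)
print(response.json())  # e.g. {"translations": ["..."]}
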
@minhthuc2502
Collaborator

To my knowledge, there is currently no limit on the input range; you can handle input validation yourself. The only restriction with a very large input is the limit of your hardware.
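
For example, the input length can be validated on the application side with a Pydantic model. A minimal sketch, assuming Pydantic v2; MAX_INPUT_TOKENS and the validator are illustrative, and the value of 1024 matches GPT-2's context window, so adjust it for your model and hardware:

import transformers
from pydantic import BaseModel, field_validator

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
MAX_INPUT_TOKENS = 1024  # GPT-2's context window; adjust for your model.

class TranslationRequest(BaseModel):
    text: str

    @field_validator("text")
    @classmethod
    def check_token_count(cls, value: str) -> str:
        # Count tokens with the same tokenizer the server uses and
        # reject requests that exceed the model's context window.
        n_tokens = len(tokenizer.encode(value))
        if n_tokens > MAX_INPUT_TOKENS:
            raise ValueError(
                f"input is {n_tokens} tokens, maximum is {MAX_INPUT_TOKENS}"
            )
        return value

With this model in place, FastAPI returns a 422 error for oversized inputs instead of passing them through to the generator.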

@rakesh-krishna
Author

@minhthuc2502 Thanks, I will experiment on the hardware and see.
To check whether there was an error message, I tried an input of 100000 and it core dumped.
Now I get it.
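
One way to avoid the core dump is to truncate at tokenization time, so oversized inputs are clipped to the model's context window instead of being passed through whole. A minimal sketch, assuming GPT-2's limit of 1024 tokens (the encode_truncated helper is just illustrative):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

def encode_truncated(text: str) -> list:
    # Clip the input to the model's context window; anything beyond
    # max_length is dropped rather than sent to the generator.
    ids = tokenizer.encode(text, truncation=True, max_length=1024)
    return tokenizer.convert_ids_to_tokens(ids)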
