Skip to content

Commit

Permalink
[Fix] Support fetching images when using worker engine (#574)
Browse files Browse the repository at this point in the history
The previous implementation relied on the HTMLImageElement constuctor
which is not available in worker contexts.
  • Loading branch information
dstoc authored Sep 26, 2024
1 parent dc2d5ea commit fde1777
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 31 deletions.
31 changes: 21 additions & 10 deletions examples/vision-model/src/vision_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ function setLabel(id: string, text: string) {
label.innerText = text;
}

const USE_WEB_WORKER = false;

const proxyUrl = "https://cors-anywhere.herokuapp.com/";
const url_https_street = "https://www.ilankelman.org/stopsigns/australia.jpg";
const url_https_tree = "https://www.ilankelman.org/sunset.jpg";
Expand All @@ -23,16 +25,25 @@ async function main() {
setLabel("init-label", report.text);
};
const selectedModel = "Phi-3.5-vision-instruct-q4f16_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{
initProgressCallback: initProgressCallback,
logLevel: "INFO", // specify the log level
},
{
context_window_size: 6144,
},
);

const engineConfig: webllm.MLCEngineConfig = {
initProgressCallback: initProgressCallback,
logLevel: "INFO", // specify the log level
};
const chatOpts = {
context_window_size: 6144,
};

const engine: webllm.MLCEngineInterface = USE_WEB_WORKER
? await webllm.CreateWebWorkerMLCEngine(
new Worker(new URL("./worker.ts", import.meta.url), {
type: "module",
}),
selectedModel,
engineConfig,
chatOpts,
)
: await webllm.CreateMLCEngine(selectedModel, engineConfig, chatOpts);

// 1. Single image input (with choices)
const messages: webllm.ChatCompletionMessageParam[] = [
Expand Down
32 changes: 11 additions & 21 deletions src/support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -411,28 +411,18 @@ export const IMAGE_EMBED_SIZE = 1921;
/**
* Given a url, get the image data. The url can either start with `http` or `data:image`.
*/
export function getImageDataFromURL(url: string): Promise<ImageData> {
return new Promise((resolve, reject) => {
// Converts img to any, and later `as CanvasImageSource`, otherwise build complains
const img: any = new Image();
img.crossOrigin = "anonymous"; // Important for CORS
img.onload = () => {
const canvas = document.createElement("canvas");
const ctx = canvas.getContext("2d");
if (!ctx) {
reject(new Error("Could not get 2d context"));
return;
}
canvas.width = img.width;
canvas.height = img.height;
ctx.drawImage(img as CanvasImageSource, 0, 0);
export async function getImageDataFromURL(url: string): Promise<ImageData> {
const response = await fetch(url, { mode: "cors" });
const img = await createImageBitmap(await response.blob());
const canvas = new OffscreenCanvas(img.width, img.height);
const ctx = canvas.getContext("2d");
if (!ctx) {
throw new Error("Could not get 2d context");
}
ctx.drawImage(img, 0, 0);

const imageData = ctx.getImageData(0, 0, img.width, img.height);
resolve(imageData);
};
img.onerror = () => reject(new Error("Failed to load image"));
img.src = url;
});
const imageData = ctx.getImageData(0, 0, img.width, img.height);
return imageData;
}

/**
Expand Down

0 comments on commit fde1777

Please sign in to comment.