Skip to content

Commit

Permalink
Add mask download
Browse files Browse the repository at this point in the history
  • Loading branch information
ihmily committed Jan 5, 2024
1 parent 2139ff5 commit 8d566e3
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 63 deletions.
73 changes: 56 additions & 17 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-

import sys
from fastapi import FastAPI, File, UploadFile, Form, Response
import os
import uuid
from datetime import datetime
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi import Request
import requests
import cv2
Expand All @@ -23,6 +26,12 @@
default_model_info = model_paths[default_model]
loaded_models = {default_model: pipeline(default_model_info['task'], model=default_model_info['path'])}

UPLOAD_FOLDER = "./upload"
OUTPUT_FOLDER = "./output"

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)


class ModelLoader:
def __init__(self):
Expand Down Expand Up @@ -62,13 +71,26 @@ async def matting(image: UploadFile = File(...), model: str = Form(default=defau

selected_model = model_loader.load_model(model)

filename = uuid.uuid4()
original_image_filename = f"original_{filename}.png"
image_filename = f"image_{filename}.png"
mask_filename = f"mask_{filename}.png"

cv2.imwrite(os.path.join(UPLOAD_FOLDER, original_image_filename), img)

result = selected_model(img)

output_img = result[OutputKeys.OUTPUT_IMG]
cv2.imwrite(os.path.join(OUTPUT_FOLDER, image_filename), result[OutputKeys.OUTPUT_IMG])
cv2.imwrite(os.path.join(OUTPUT_FOLDER, mask_filename), result['output_img'][:, :, 3])

output_bytes = cv2.imencode('.png', output_img)[1].tobytes()
response_data = {
"result_image_url": f"/output/{image_filename}",
"mask_image_url": f"/output/{mask_filename}",
"original_image_size": {"width": img.shape[1], "height": img.shape[0]},
"generation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}

return Response(content=output_bytes, media_type='image/png')
return response_data


@app.post("/matting/url")
Expand All @@ -77,34 +99,50 @@ async def matting_url(request: Request, model: str = Form(default=default_model,
json_data = await request.json()
image_url = json_data.get("image_url")
except Exception as e:
return {"content": f"Error parsing JSON data: {str(e)}"}, 400
raise HTTPException(status_code=400, detail=f"Error parsing JSON data: {str(e)}")

if not image_url:
return {"content": "Image URL is required"}, 400

response = requests.get(image_url)
if response.status_code != 200:
return {"content": "Failed to fetch image from URL"}, 400
raise HTTPException(status_code=400, detail="Image URL is required")

img_array = np.frombuffer(response.content, dtype=np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
try:
response = requests.get(image_url)
response.raise_for_status()
img_array = np.frombuffer(response.content, dtype=np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
except requests.RequestException as e:
raise HTTPException(status_code=400, detail=f"Failed to fetch image from URL: {str(e)}")

if model not in model_paths:
return {"content": "Invalid model selection"}, 400
raise HTTPException(status_code=400, detail="Invalid model selection")

selected_model = model_loader.load_model(model)

filename = uuid.uuid4()
original_image_filename = f"original_{filename}.png"
image_filename = f"image_{filename}.png"
mask_filename = f"mask_{filename}.png"

cv2.imwrite(os.path.join(UPLOAD_FOLDER, original_image_filename), img)

result = selected_model(img)

output_img = result[OutputKeys.OUTPUT_IMG]
cv2.imwrite(os.path.join(OUTPUT_FOLDER, image_filename), result[OutputKeys.OUTPUT_IMG])
cv2.imwrite(os.path.join(OUTPUT_FOLDER, mask_filename), result['output_img'][:, :, 3])

output_bytes = cv2.imencode('.png', output_img)[1].tobytes()
response_data = {
"result_image_url": f"/output/{image_filename}",
"mask_image_url": f"/output/{mask_filename}",
"original_image_size": {"width": img.shape[1], "height": img.shape[0]},
"generation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}

return Response(content=output_bytes, media_type='image/png')
return response_data


templates = Jinja2Templates(directory="web")
app.mount("/static", StaticFiles(directory="web/static"), name="static")
app.mount("/static", StaticFiles(directory="./web/static"), name="static")
app.mount("/output", StaticFiles(directory="./output"), name="output")
app.mount("/upload", StaticFiles(directory="./upload"), name="upload")


@app.get("/")
Expand All @@ -115,5 +153,6 @@ async def read_index(request: Request):

if __name__ == "__main__":
import uvicorn

defult_bind_host = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1"
uvicorn.run(app, host=defult_bind_host, port=8000)
63 changes: 17 additions & 46 deletions web/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,13 @@
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Image Matting</title>
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
}

header {
background-color: #333;
color: white;
padding: 1em;
text-align: center;
}

main {
max-width: 800px;
margin: 2em auto;
background-color: white;
padding: 2em;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
text-align: center;
}

#upload-form {
text-align: center;
margin-bottom: 2em;
}

#images-container {
display: flex;
justify-content: space-between;
align-items: center;
}

#original-img,
#result-img {
max-width: 45%;
height: auto;
border: 1px solid #ccc;
}
</style>
<title>Simple Image Matting</title>
<link rel="stylesheet" href="./static/css/style.css">
</head>

<body>
<header>
<h1>Image Matting</h1>
<h1>Simple Image Matting</h1>
</header>
<a href="https://github.com/ihmily/image-matting"><img style="position: absolute; top: 0; right: 0; border: 0;" decoding="async" width="149" height="149" src="/static/images/forkme_right_gray_6d6d6d.png" class="attachment-full size-full" alt="Fork me on GitHub" loading="lazy" data-recalc-dims="1"></a>
<main>
Expand All @@ -71,10 +30,15 @@ <h1>Image Matting</h1>

<div id="images-container">
<img id="original-img" alt=" " src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7">
<img id="mask-img" alt=" " src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7">
<img id="result-img" alt=" " src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7">
</div>
</main>

<footer>
<p>&copy; 2024 Hmily. All rights reserved.</p>
</footer>

<script>
document.getElementById("image-upload").addEventListener("change", function() {
const inputElement = this;
Expand All @@ -84,6 +48,7 @@ <h1>Image Matting</h1>
const originalImgElement = document.getElementById("original-img");
originalImgElement.src = URL.createObjectURL(file);
document.getElementById("result-img").src = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7";
document.getElementById("mask-img").src = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7";
}
});

Expand All @@ -106,9 +71,15 @@ <h1>Image Matting</h1>
});

if (response.ok) {
const resultBlob = await response.blob();
const responseData = await response.json();

const resultImgElement = document.getElementById("result-img");
resultImgElement.src = URL.createObjectURL(resultBlob);
resultImgElement.src = responseData.result_image_url;

const maskImgElement = document.getElementById("mask-img");
maskImgElement.src = responseData.mask_image_url;

console.log(`Matting successful!\nOriginal Image Size: ${responseData.original_image_size.width} x ${responseData.original_image_size.height}\nGeneration Time: ${responseData.generation_time}`);
} else {
alert("Matting failed. Please try again.");
}
Expand Down
53 changes: 53 additions & 0 deletions web/static/css/style.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
body {
font-family: 'Arial', sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
}

header {
background-color: #333;
color: white;
padding: 1em;
text-align: center;
}

main {
max-width: 1200px;
margin: 2em auto;
background-color: white;
padding: 2em;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
text-align: center;
}

#upload-form {
text-align: center;
margin-bottom: 2em;
}

#images-container {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: 2em; /* Add spacing */
}

#original-img,
#mask-img,
#result-img {
max-width: 30%;
height: auto;
border: 1px solid #ccc;
margin: 0.5em; /* Add spacing */
}

footer {
background-color: #333;
color: white;
text-align: center;
padding: 1em;
position: fixed;
bottom: 0;
width: 100%;
}

0 comments on commit 8d566e3

Please sign in to comment.