app.py
import streamlit as st
import tensorflow as tf
import time
from PIL import Image, ImageOps
import numpy as np
# import json
# import requests
# from streamlit_lottie import st_lottie
# Load the saved model
model = tf.keras.models.load_model("./models/CNN_Augmented_100_model.h5", compile=False)
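# Note: the saved CNN takes 28x28 single-channel (grayscale) inputs, which is what
# preprocess_image() below produces; compile=False skips restoring the training
# configuration, since the model is only used for inference here.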
# TODO: train a generalized model by inverting the colour of half of the digits.
def contrast_stretching(image):
    # Calculate the minimum and maximum pixel values in the image
    min_val = np.min(image)
    max_val = np.max(image)

    # Guard against division by zero for uniform (single-intensity) images
    if max_val == min_val:
        return np.zeros_like(image, dtype=np.uint8)

    # Apply min-max contrast stretching
    stretched_image = ((image - min_val) / (max_val - min_val)) * 255
    stretched_image = stretched_image.astype(np.uint8)  # Convert to uint8
    return stretched_image
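# Illustrative example (comment only): pixels spanning [10, 30] are remapped so that
# 10 -> 0, 20 -> ~127 and 30 -> 255, i.e. the output always spans the full 0-255 range.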
def preprocess_image(image):
    # Convert the image to grayscale
    gray_image = image.convert("L")

    # TODO: background correction (uniform illumination for pen-and-paper images),
    # possibly via homomorphic filtering.

    # Calculate the average intensity to decide whether the background is light or dark
    pixels = list(gray_image.getdata())
    avg_intensity = sum(pixels) / len(pixels)

    # MNIST digits are light-on-dark, so invert images with a light background
    if avg_intensity > 100:  # Light background
        gray_image = Image.eval(gray_image, lambda x: 255 - x)

    # Apply histogram equalization
    equalized_image = ImageOps.equalize(gray_image)

    # Resize the image to match the model input size
    resized_image = equalized_image.resize((28, 28))

    # Apply contrast stretching
    return contrast_stretching(resized_image)
# Define a function for model inference
# @tf.function
def predict(image):
    # Open and preprocess the image
    img = Image.open(image)
    preprocessed_image = preprocess_image(img)
    # st.image(preprocessed_image, width=500, caption="pre")

    # Convert the image to a NumPy array and add the channel dimension
    img_array = np.array(preprocessed_image)
    img_array = np.expand_dims(img_array, axis=-1)  # (28, 28) -> (28, 28, 1)

    # Make a prediction with the loaded model (batch dimension added here)
    prediction = model.predict(np.expand_dims(img_array, axis=0))
    return prediction
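# A minimal local sanity check (illustrative sketch, not part of the Streamlit UI);
# any of the bundled sample images referenced further below would work as input:
#
#     with open("./Real_Life_Images/seven4.png", "rb") as f:
#         probs = predict(f)
#     print("Predicted digit:", int(np.argmax(probs)))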
# Streamlit app code
st.set_page_config(
    page_title="Classification of Handwritten Digits",
    layout="centered",
    initial_sidebar_state="expanded",
    menu_items={
        "Get Help": "https://github.com/Subhranil2004/digits-in-ink#live-demo",
        "Report a bug": "https://github.com/Subhranil2004/handwritten-digit-classification/issues",
    },
)
# Sidebar
with st.sidebar:
    st.image(
        "./images/image.jpg",
        use_column_width=True,
        output_format="JPEG",
    )

st.sidebar.subheader(":blue[[Please use a desktop for the best experience.]]")
st.sidebar.title("Classification of Handwritten Digits [0 - 9]")
st.sidebar.write(
    "The model is trained on the ***MNIST dataset*** using a Convolutional Neural Network with data augmentation, and achieves 99.45% accuracy on the MNIST test set."
)
st.sidebar.write(
    "There is always scope for improvement, and I would appreciate any suggestions and/or constructive criticism."
)
st.sidebar.link_button("GitHub", "https://github.com/Subhranil2004")
st.markdown(
    """
    <style>
    .sidebar {
        width: 500px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# Main content
st.title("Classification of Handwritten Digits")

uploaded_file = st.file_uploader(
    "Choose an image...",
    type=["jpg", "png", "bmp", "tiff"],
)
if uploaded_file is not None:
    # Display the uploaded image
    st.image(
        uploaded_file,
        caption="Uploaded Image",
        width=300,
        clamp=True,
        # output_format="JPEG",
    )

    # Perform prediction
    if st.button("Predict"):
        result = predict(uploaded_file)

        progress_bar = st.progress(0)
        status_text = st.empty()
        for i in range(100):
            progress_bar.progress(i + 1)
            status_text.text(f"Progress: {i + 1}%")
            time.sleep(0.00)
        status_text.text("Done!")

        # Display the prediction result
        max_index = np.argmax(result)
        st.write("Predicted Digit:", max_index, "🎉")

        # Find the second most probable digit: remove the top class, take the next
        # argmax, and shift the index back if it fell after the removed position
        del_result = np.delete(result, max_index)
        res2 = np.argmax(del_result)
        if res2 >= max_index:
            res2 = res2 + 1

        expander = st.expander(
            ":orange[If you aren't satisfied with the result, check the prediction probabilities below ⬇⚠️]"
        )
        expander.write(f"Second most probable prediction: {res2}")
        expander.write(result)
expander = st.expander("Some real life images to try with...", expanded=True)
expander.write("Just drag-and-drop your chosen image above ")
expander.image(
[
"./Real_Life_Images/seven4.png",
"./Real_Life_Images/six2.png",
"./Real_Life_Images/two1.png",
"./Real_Life_Images/nine1.png",
"./Real_Life_Images/zero1.png",
"./Real_Life_Images/seven3.png",
"./Real_Life_Images/eight3.png",
"./Real_Life_Images/one1.png",
"./Real_Life_Images/four6.png",
"./Real_Life_Images/five6.png",
"./Real_Life_Images/zero2.png",
"./Real_Life_Images/nine4.png",
"./Real_Life_Images/three6.png",
"./Real_Life_Images/four7.png",
],
width=95,
)
expander.write(
"All images might not give the desired result as the *1st* prediction due to low contrast. Check the probability scores in such cases."
)
expander = st.expander("View Model Training and Validation Results")
expander.write("Confusion Matrix: ")
expander.image("./images/CNN_ConfusionMatrix.png", use_column_width=True)
expander.write("Graphs: ")
expander.image("./images/CNN_Graphs.png", use_column_width=True)
expander = st.expander("If you are getting inaccurate results, follow these steps:")
expander.markdown(
"""
1. Use OneNote/MS Paint solid-colour background
2. Upload small to medium size images (ideally under (600 x 600))
3. If large sized images are uploaded thicken its stroke
4. Make sure digit occupies the maximum part of the image
`Pen and paper images` aren't compatible with the classifier till now due to non-uniformity of background colour (illumination). Actively working on that !
:red[Note: ] :orange[If you are getting incorrect predictions even after following the above steps, kindly drop in the image in `Report a bug` option ,i.e, by creating an issue in GitHub. That will help in improving the app further 🙂]
"""
)
# Footer
st.write("\n\n\n")
st.markdown("---")
st.markdown(
    "Drop in any discrepancies or suggestions via the `Report a bug` option within the `⋮` menu."
)
st.markdown(
    """<div style="text-align: right"> Developed by Subhranil Nandy </div>""",
    unsafe_allow_html=True,
)
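# To launch the app locally (assuming Streamlit and TensorFlow are installed):
#     streamlit run app.py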