-
Notifications
You must be signed in to change notification settings - Fork 1
/
make_xyc_dataset.py
50 lines (36 loc) · 1.16 KB
/
make_xyc_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 23 2023
"""
import os
import numpy as np
from tqdm import tqdm
from itertools import product
def circle(img, center, radius, color):
    """Draw a filled disc of value ``color`` into the 2-D array ``img``, in place.

    A pixel ``(r, c)`` is painted when its Euclidean distance to ``center`` is
    strictly less than ``radius``; parts of the disc that fall outside the
    image are simply not drawn.

    Args:
        img: 2-D array of shape ``(m, n)``; modified in place.
        center: ``(row, col)`` pair -- ``center[0]`` is compared against the
            first-axis (row) index and ``center[1]`` against the column index.
            A scalar is broadcast to both coordinates.
        radius: disc radius in pixels (strict ``<`` comparison).
        color: value written into every pixel inside the disc.

    Returns:
        The same ``img`` object that was passed in.
    """
    m, n = img.shape
    c_row, c_col = np.broadcast_to(np.asarray(center, dtype=np.float32), 2)
    # Broadcast a column of row offsets against a row of column offsets rather
    # than materialising every (row, col) pair with itertools.product. Using
    # both m and n handles non-square images (the column axis used to be built
    # from m), and indexing with a boolean mask removes the old uint8 index
    # cast, which wrapped indices >= 256 back into the image.
    rows = np.arange(m, dtype=np.float32)[:, None] - c_row
    cols = np.arange(n, dtype=np.float32)[None, :] - c_col
    # float32 sqrt of the sum of squares: same arithmetic as np.linalg.norm
    # applied to the float32 coordinate pairs.
    in_circle = np.sqrt(rows * rows + cols * cols) < radius
    img[in_circle] = color
    return img
# Dataset geometry: IMG_SIZE x IMG_SIZE single-channel float32 images, each
# containing one filled disc of radius RADIUS drawn with `circle`.
IMG_SIZE = 84
RADIUS = 15
# Disc centres sweep range(15, 84 - 15 - 1) == 15..67 on both axes.
CENTRE_RANGE = range(RADIUS, IMG_SIZE - RADIUS - 1)
COLORS = [0.2, 0.4, 0.6, 0.8, 1.0]

n_samples = len(CENTRE_RANGE) ** 2 * len(COLORS)
# Preallocate the image stack and draw each disc straight into its slot.
# Collecting thousands of arrays in a list and then calling np.stack would
# hold two full copies (~400 MB each) in memory at the peak.
datasets = np.zeros((n_samples, IMG_SIZE, IMG_SIZE), dtype=np.float32)
# One (j, i, c) row per image: the centre passed to `circle` and the colour.
labels = np.empty((n_samples, 3), dtype=np.float64)
k = 0
for i in tqdm(CENTRE_RANGE):
    for j in CENTRE_RANGE:
        for c in COLORS:
            # `circle` paints in place, so drawing into the view datasets[k]
            # fills the preallocated array directly.
            circle(datasets[k], (j, i), RADIUS, c)
            labels[k] = (j, i, c)
            k += 1
print(n_samples)

dataset_folder_name = 'datasets'
# exist_ok=True tolerates only "directory already exists"; the previous
# `except OSError: pass` also hid permission errors, which then surfaced as a
# confusing failure inside np.savez.
os.makedirs(dataset_folder_name, exist_ok=True)
np.savez(os.path.join(dataset_folder_name, 'xyc.npz'), imgs=datasets, labs=labels)