forked from AmusementClub/vs-dfttest2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dfttest2.py
798 lines (671 loc) · 24.8 KB
/
dfttest2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
# Version string of this wrapper module.
__version__ = "0.3.3"
from dataclasses import dataclass
import math
from string import Template
import typing
import vapoursynth as vs
from vapoursynth import core
# Names exported by `from <this module> import *`.
__all__ = ["DFTTest", "DFTTest2", "Backend"]
class Backend:
    # Namespace holding one small, mutable option record per backend.
    # DFTTest2 picks the plugin to call from the record's type and forwards
    # the record's fields to that plugin.

    @dataclass
    class cuFFT:
        # Both fields are forwarded to core.dfttest2_cuda.DFTTest.
        device_id: int = 0
        in_place: bool = True

    @dataclass
    class NVRTC:
        # Both fields are forwarded to core.dfttest2_nvrtc.DFTTest.
        device_id: int = 0
        num_streams: int = 1

    @dataclass
    class CPU:
        # Forwarded as `opt` to core.dfttest2_cpu.DFTTest.
        opt: int = 0

    @dataclass
    class GCC:
        # core.dfttest2_gcc.DFTTest takes no backend-specific options.
        pass
# Type alias: an instance of any of the backend option records in `Backend`.
backendT = typing.Union[Backend.cuFFT, Backend.NVRTC, Backend.CPU, Backend.GCC]
def init_backend(backend: backendT) -> backendT:
    """Accept a backend class in place of an instance.

    If `backend` is one of the `Backend.*` classes themselves (for example
    `Backend.CPU` rather than `Backend.CPU()`), a default-constructed
    instance of that class is returned. Anything else is returned unchanged.
    """
    known_backends = (Backend.cuFFT, Backend.NVRTC, Backend.CPU, Backend.GCC)
    for backend_cls in known_backends:
        if backend is backend_cls:  # type: ignore
            return backend_cls()
    return backend
# https://github.com/HomeOfVapourSynthEvolution/VapourSynth-DFTTest/blob/
# bc5e0186a7f309556f20a8e9502f2238e39179b8/DFTTest/DFTTest.cpp#L518
def normalize(
    window: typing.Sequence[float],
    size: int,
    step: int
) -> typing.List[float]:
    """Normalize a 1-D window for overlapped processing.

    Each of the first `size` entries is divided by the square root of the
    summed squares of all entries whose index differs from it by a multiple
    of `step`.
    """

    def overlap_energy(q: int) -> float:
        # Walk downwards from q first, then upwards from q + step, so the
        # floating-point accumulation order is fixed.
        acc = 0.0
        for h in range(q, -1, -step):
            acc += window[h] ** 2
        for h in range(q + step, size, step):
            acc += window[h] ** 2
        return acc

    return [window[q] / math.sqrt(overlap_energy(q)) for q in range(size)]
# https://github.com/HomeOfVapourSynthEvolution/VapourSynth-DFTTest/blob/
# bc5e0186a7f309556f20a8e9502f2238e39179b8/DFTTest/DFTTest.cpp#L462
def get_window_value(location: float, size: int, mode: int, beta: float) -> float:
    """Evaluate one sample of an analysis window.

    Args:
        location: Sample position within the window.
        size: Window length.
        mode: Window type, 0 to 11 (see the inline labels below).
        beta: Shape parameter; only read by the kaiser-bessel window (mode 4).

    Returns:
        The window value at `location`.

    Raises:
        ValueError: If `mode` is not one of the known window types.
    """
    phase = math.pi * location / size

    if mode == 0:  # hanning
        return 0.5 * (1 - math.cos(2 * phase))

    if mode == 1:  # hamming
        return 0.53836 - 0.46164 * math.cos(2 * phase)

    if mode == 2:  # blackman
        return 0.42 - 0.5 * math.cos(2 * phase) + 0.08 * math.cos(4 * phase)

    if mode == 3:  # 4 term blackman-harris
        return (
            0.35875
            - 0.48829 * math.cos(2 * phase)
            + 0.14128 * math.cos(4 * phase)
            - 0.01168 * math.cos(6 * phase)
        )

    if mode == 4:  # kaiser-bessel
        def i0(p: float) -> float:
            # Truncated series: at most 14 terms, stopping early once a term
            # drops to 1e-8 or below.
            p /= 2
            n = t = d = 1.0
            for k in range(1, 15):
                n *= p
                d *= k
                v = n / d
                t += v * v
                if v <= 1e-8:
                    break
            return t

        v = 2 * location / size - 1
        return i0(math.pi * beta * math.sqrt(1 - v * v)) / i0(math.pi * beta)

    if mode == 5:  # 7 term blackman-harris
        return (
            0.27105140069342415
            - 0.433297939234486060 * math.cos(2 * phase)
            + 0.218122999543110620 * math.cos(4 * phase)
            - 0.065925446388030898 * math.cos(6 * phase)
            + 0.010811742098372268 * math.cos(8 * phase)
            - 7.7658482522509342e-4 * math.cos(10 * phase)
            + 1.3887217350903198e-5 * math.cos(12 * phase)
        )

    if mode == 6:  # flat top
        return (
            0.2810639
            - 0.5208972 * math.cos(2 * phase)
            + 0.1980399 * math.cos(4 * phase)
        )

    if mode == 7:  # rectangular
        return 1.0

    if mode == 8:  # Bartlett
        return 1 - 2 * abs(location - size / 2) / size

    if mode == 9:  # bartlett-hann
        return 0.62 - 0.48 * (location / size - 0.5) - 0.38 * math.cos(2 * phase)

    if mode == 10:  # nuttall
        return (
            0.355768
            - 0.487396 * math.cos(2 * phase)
            + 0.144232 * math.cos(4 * phase)
            - 0.012604 * math.cos(6 * phase)
        )

    if mode == 11:  # blackman-nuttall
        return (
            0.3635819
            - 0.4891775 * math.cos(2 * phase)
            + 0.1365995 * math.cos(4 * phase)
            - 0.0106411 * math.cos(6 * phase)
        )

    raise ValueError("unknown window")
# https://github.com/HomeOfVapourSynthEvolution/VapourSynth-DFTTest/blob/
# bc5e0186a7f309556f20a8e9502f2238e39179b8/DFTTest/DFTTest.cpp#L461
def get_window(
    radius: int,
    block_size: int,
    block_step: int,
    spatial_window_mode: int,
    spatial_beta: float,
    temporal_window_mode: int,
    temporal_beta: float
) -> typing.List[float]:
    """Build the flattened spatio-temporal analysis window.

    The temporal window has `2 * radius + 1` samples and the spatial window
    `block_size` samples, both evaluated at half-sample offsets. Only the
    spatial window is passed through `normalize`. The result lists, in
    (t, y, x) order, the product of the three 1-D windows divided by
    `sqrt(2 * radius + 1) * block_size`.
    """
    temporal_size = 2 * radius + 1

    temporal_window = [
        get_window_value(
            location=i + 0.5,
            size=temporal_size,
            mode=temporal_window_mode,
            beta=temporal_beta
        )
        for i in range(temporal_size)
    ]

    raw_spatial_window = [
        get_window_value(
            location=i + 0.5,
            size=block_size,
            mode=spatial_window_mode,
            beta=spatial_beta
        )
        for i in range(block_size)
    ]
    spatial_window = normalize(
        window=raw_spatial_window,
        size=block_size,
        step=block_step
    )

    # The division compensates for the unnormalized FFT implementation.
    return [
        t_val * s_val1 * s_val2 / (math.sqrt(2 * radius + 1) * block_size)
        for t_val in temporal_window
        for s_val1 in spatial_window
        for s_val2 in spatial_window
    ]
# https://github.com/HomeOfVapourSynthEvolution/VapourSynth-DFTTest/blob/
# bc5e0186a7f309556f20a8e9502f2238e39179b8/DFTTest/DFTTest.cpp#L581
def get_location(
    position: float,
    length: int
) -> float:
    """Map a coefficient index to a location relative to `length // 2`.

    Positions above `length // 2` are mirrored (`length - position`) before
    being divided by `length // 2`. A `length` of 1 always yields 0.0.
    """
    if length == 1:
        return 0.0

    half = length // 2
    distance = (length - position) if position > half else position
    return distance / half
# https://github.com/HomeOfVapourSynthEvolution/VapourSynth-DFTTest/blob/
# bc5e0186a7f309556f20a8e9502f2238e39179b8/DFTTest/DFTTest.cpp#L581
def get_sigma(
    position: float,
    length: int,
    func: typing.Callable[[float], float]
) -> float:
    """Evaluate `func` at the location of `position` along one axis.

    An axis of length 1 contributes the neutral factor 1.0 and `func` is
    not called.
    """
    return 1.0 if length == 1 else func(get_location(position, length))
def DFTTest2(
    clip: vs.VideoNode,
    ftype: typing.Literal[0, 1, 2, 3, 4] = 0,
    sigma: typing.Union[float, typing.Sequence[typing.Callable[[float], float]]] = 8.0,
    sigma2: float = 8.0,
    pmin: float = 0.0,
    pmax: float = 500.0,
    sbsize: int = 16,
    sosize: int = 12,
    tbsize: int = 3,
    swin: typing.Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] = 0,
    twin: typing.Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] = 7,
    sbeta: float = 2.5,
    tbeta: float = 2.5,
    zmean: bool = True,
    f0beta: float = 1.0,
    ssystem: typing.Literal[0, 1] = 0,
    planes: typing.Optional[typing.Union[int, typing.Sequence[int]]] = None,
    backend: backendT = Backend.cuFFT()
) -> vs.VideoNode:
    """Lower-level entry point used by `DFTTest`. This interface is not stable.

    Args:
        clip: Clip to process.
        ftype: Filter type. With ftype=0, `f0beta` selects the internal
            variant: about 1 -> plain wiener (0), about 0.5 -> square root
            (6), anything else -> general power (5).
        sigma: Either a scalar, a single callable, or a sequence of up to
            three callables (x, y, t) mapping a location to a sigma value.
            Missing entries repeat the last one given.
        sigma2, pmin, pmax: Extra filter parameters (see `ftype`).
        sbsize, sosize: Spatial block size and overlap; the block step is
            `sbsize - sosize`.
        tbsize: Temporal block size; the radius is `(tbsize - 1) // 2`.
        swin, twin, sbeta, tbeta: Spatial/temporal window type and beta.
        zmean: Emitted as ZERO_MEAN into the generated CUDA kernel.
        f0beta: Power term for ftype=0.
        ssystem: 0 multiplies per-axis sigmas; 1 evaluates the third sigma
            function at a combined location.
        planes: Forwarded to the plugin.
        backend: Backend option record (or backend class).

    Returns:
        The node produced by the selected plugin's `DFTTest` function.

    Raises:
        ValueError: If the CPU, NVRTC or GCC backend is used with a radius
            outside 0..3 or with `sbsize != 16`.
        TypeError: If `backend` is not a known backend.
    """

    # translate parameters
    if ftype == 0:
        if abs(f0beta - 1) < 0.00005:
            filter_type = 0
        elif abs(f0beta - 0.5) < 0.0005:
            filter_type = 6
        else:
            filter_type = 5
    else:
        filter_type = ftype

    radius = (tbsize - 1) // 2
    temporal_size = 2 * radius + 1
    block_size = sbsize
    block_step = sbsize - sosize
    spatial_window_mode = swin
    temporal_window_mode = twin
    spatial_beta = sbeta
    temporal_beta = tbeta
    zero_mean = zmean

    backend = init_backend(backend)

    if isinstance(backend, (Backend.CPU, Backend.NVRTC, Backend.GCC)):
        if radius not in range(4):
            raise ValueError("invalid radius (tbsize)")
        if block_size != 16:
            raise ValueError("invalid block_size (sbsize)")

    # compute constants

    # Only the scalar probe lives in the try block. A bare `except:` here
    # would also swallow KeyboardInterrupt/SystemExit, and building the
    # sigma table inside the handler would chain every user-callback error
    # onto the probe's exception.
    try:
        sigma_scalar = float(sigma)  # type: ignore
    except (TypeError, ValueError):
        sigma_is_scalar = False
    else:
        sigma_is_scalar = True

    sigma_array: typing.List[float] = []
    if not sigma_is_scalar:
        if callable(sigma):
            sigma_funcs = [sigma]
        else:
            sigma_funcs = list(typing.cast(
                typing.Sequence[typing.Callable[[float], float]], sigma
            ))
        # Pad so that three functions (x, y, t) are always available.
        sigma_funcs.extend([sigma_funcs[-1]] * 3)
        sigma_func_x, sigma_func_y, sigma_func_t = sigma_funcs[:3]

        if ssystem == 0:
            for t in range(temporal_size):
                sigma_t = get_sigma(position=t, length=temporal_size, func=sigma_func_t)
                for y in range(block_size):
                    sigma_y = get_sigma(position=y, length=block_size, func=sigma_func_y)
                    for x in range(block_size // 2 + 1):
                        sigma_x = get_sigma(position=x, length=block_size, func=sigma_func_x)
                        sigma_array.append(sigma_t * sigma_y * sigma_x)
        else:
            ndim = 3 if radius > 0 else 2
            for t in range(temporal_size):
                loc_t = get_location(position=t, length=temporal_size)
                for y in range(block_size):
                    loc_y = get_location(position=y, length=block_size)
                    for x in range(block_size // 2 + 1):
                        loc_x = get_location(position=x, length=block_size)
                        location = math.sqrt(
                            (loc_t * loc_t + loc_y * loc_y + loc_x * loc_x) / ndim
                        )
                        sigma_array.append(sigma_func_t(location))

    window = get_window(
        radius=radius,
        block_size=block_size,
        block_step=block_step,
        spatial_window_mode=spatial_window_mode,
        temporal_window_mode=temporal_window_mode,
        spatial_beta=spatial_beta,
        temporal_beta=temporal_beta
    )

    wscale = math.fsum(w * w for w in window)

    if ftype < 2:
        if sigma_is_scalar:
            sigma_scalar *= wscale
        else:
            sigma_array = [s * wscale for s in sigma_array]
        sigma2 *= wscale

    if filter_type == 5:
        # Filter 5 reads its exponent from the pmin slot (see
        # `powf(..., pmin)` in the kernel below), so hand it f0beta
        # unscaled. Previously f0beta was never forwarded and the exponent
        # was pmin * wscale (0 by default, which disables filtering).
        # NOTE(review): assumes the CPU/GCC plugins follow the same
        # convention as the CUDA kernel; confirm against the plugin sources.
        pmin = f0beta
    else:
        pmin *= wscale
    pmax *= wscale

    if isinstance(backend, Backend.cuFFT):
        rdft = core.dfttest2_cuda.RDFT
    elif isinstance(backend, Backend.NVRTC):
        rdft = core.dfttest2_nvrtc.RDFT
    elif isinstance(backend, Backend.CPU):
        rdft = core.dfttest2_cpu.RDFT
    elif isinstance(backend, Backend.GCC):
        rdft = core.dfttest2_gcc.RDFT
    else:
        raise TypeError("unknown backend")

    if radius == 0:
        window_shape: typing.Tuple[int, ...] = (block_size, block_size)
    else:
        window_shape = (temporal_size, block_size, block_size)
    window_freq = rdft(
        data=[w * 255 for w in window],
        shape=window_shape
    )

    if isinstance(backend, (Backend.CPU, Backend.GCC)):
        # These plugins take one sigma per frequency coefficient.
        if sigma_is_scalar:
            num_coeffs = temporal_size * block_size * (block_size // 2 + 1)
            sigma_values = [sigma_scalar] * num_coeffs
        else:
            sigma_values = sigma_array

        host_args = dict(
            window=window,
            sigma=sigma_values,
            sigma2=sigma2,
            pmin=pmin,
            pmax=pmax,
            radius=radius,
            block_size=block_size,
            block_step=block_step,
            planes=planes,
            filter_type=filter_type,
            window_freq=window_freq
        )
        if isinstance(backend, Backend.CPU):
            return core.dfttest2_cpu.DFTTest(clip, opt=backend.opt, **host_args)
        return core.dfttest2_gcc.DFTTest(clip, **host_args)

    if isinstance(backend, Backend.cuFFT):
        to_single = core.dfttest2_cuda.ToSingle
    elif isinstance(backend, Backend.NVRTC):
        to_single = core.dfttest2_nvrtc.ToSingle
    else:
        raise TypeError("unknown backend")

    # The kernel text is kept flush-left: it is a runtime string, so its
    # content must not pick up Python indentation.
    kernel = Template(
        """
#define FILTER_TYPE ${filter_type}
#define ZERO_MEAN ${zero_mean}
#define SIGMA_IS_SCALAR ${sigma_is_scalar}
#if ZERO_MEAN
__device__ static const float window_freq[] { ${window_freq} };
#endif // ZERO_MEAN
__device__ static const float window[] { ${window} };
__device__
static void filter(float2 & value, int x, int y, int t) {
#if SIGMA_IS_SCALAR
float sigma = static_cast<float>(${sigma});
#else // SIGMA_IS_SCALAR
__device__ static const float sigma_array[] { ${sigma} };
float sigma = sigma_array[(t * BLOCK_SIZE + y) * (BLOCK_SIZE / 2 + 1) + x];
#endif // SIGMA_IS_SCALAR
[[maybe_unused]] float sigma2 = static_cast<float>(${sigma2});
[[maybe_unused]] float pmin = static_cast<float>(${pmin});
[[maybe_unused]] float pmax = static_cast<float>(${pmax});
[[maybe_unused]] float multiplier {};
#if FILTER_TYPE == 2
value.x *= sigma;
value.y *= sigma;
return ;
#endif
float psd = value.x * value.x + value.y * value.y;
#if FILTER_TYPE == 1
if (psd < sigma) {
value.x = 0.0f;
value.y = 0.0f;
}
return ;
#elif FILTER_TYPE == 0
multiplier = fmaxf((psd - sigma) / (psd + 1e-15f), 0.0f);
#elif FILTER_TYPE == 3
if (psd >= pmin && psd <= pmax) {
multiplier = sigma;
} else {
multiplier = sigma2;
}
#elif FILTER_TYPE == 4
multiplier = sigma * sqrtf(psd * (pmax / ((psd + pmin) * (psd + pmax) + 1e-15f)));
#elif FILTER_TYPE == 5
multiplier = powf(fmaxf((psd - sigma) / (psd + 1e-15f), 0.0f), pmin);
#else
multiplier = sqrtf(fmaxf((psd - sigma) / (psd + 1e-15f), 0.0f));
#endif
value.x *= multiplier;
value.y *= multiplier;
}
"""
    ).substitute(
        sigma_is_scalar=int(sigma_is_scalar),
        sigma=(
            to_single(sigma_scalar)
            if sigma_is_scalar
            else ','.join(str(to_single(x)) for x in sigma_array)
        ),
        sigma2=to_single(sigma2),
        pmin=to_single(pmin),
        pmax=to_single(pmax),
        filter_type=int(filter_type),
        window_freq=','.join(str(to_single(x)) for x in window_freq),
        zero_mean=int(zero_mean),
        window=','.join(str(to_single(x)) for x in window),
    )

    if isinstance(backend, Backend.cuFFT):
        return core.dfttest2_cuda.DFTTest(
            clip,
            kernel=kernel,
            radius=radius,
            block_size=block_size,
            block_step=block_step,
            planes=planes,
            in_place=backend.in_place,
            device_id=backend.device_id
        )
    elif isinstance(backend, Backend.NVRTC):
        return core.dfttest2_nvrtc.DFTTest(
            clip,
            kernel=kernel,
            radius=radius,
            block_size=block_size,
            block_step=block_step,
            planes=planes,
            in_place=False,
            device_id=backend.device_id,
            num_streams=backend.num_streams
        )
    else:
        raise TypeError("unknown backend")
def select_backend(
    backend: typing.Optional[backendT],
    sbsize: int,
    tbsize: int
) -> backendT:
    """Choose a backend when the caller did not specify one.

    An explicit `backend` is returned untouched. Otherwise, for sbsize=16
    and tbsize in {1, 3, 5, 7}, the first plugin found on `core` wins, in
    the order NVRTC, cuFFT, CPU, with GCC as the last resort. Every other
    block geometry falls back to cuFFT.
    """
    if backend is not None:
        return backend

    if sbsize != 16 or tbsize not in [1, 3, 5, 7]:
        return Backend.cuFFT()

    if hasattr(core, "dfttest2_nvrtc"):
        return Backend.NVRTC()
    if hasattr(core, "dfttest2_cuda"):
        return Backend.cuFFT()
    if hasattr(core, "dfttest2_cpu"):
        return Backend.CPU()
    return Backend.GCC()
# Aliases used to label the two members of a (frequency, sigma) pair.
FREQ = float
SIGMA = float


def flatten(
    data: typing.Optional[typing.Union[
        typing.Sequence[typing.Tuple[FREQ, SIGMA]],
        typing.Sequence[float]
    ]]
) -> typing.Optional[typing.List[float]]:
    """Turn a location/sigma specification into one flat list.

    Accepts either a flat sequence `[f0, s0, f1, s1, ...]` or a sequence of
    pairs `[(f0, s0), (f1, s1), ...]`; the form is detected from the type
    of the first element.

    Returns:
        A new flat list, or None when `data` is None or empty (an empty
        specification is treated the same as "not given" instead of failing
        with IndexError on `data[0]`).
    """
    import itertools as it
    import numbers

    if data is None or len(data) == 0:
        return None

    if isinstance(data[0], numbers.Real):
        # Copy so the declared List[float] return type holds for tuples and
        # the caller's object is never shared.
        return list(typing.cast(typing.Sequence[float], data))

    pairs = typing.cast(typing.Sequence[typing.Tuple[FREQ, SIGMA]], data)
    return list(it.chain.from_iterable(pairs))
def to_func(
    data: typing.Optional[typing.Sequence[float]],
    norm: typing.Callable[[float], float],
    sigma: float
) -> typing.Callable[[float], float]:
    """Build a sigma-versus-location function.

    Args:
        data: Flat `[loc0, sigma0, loc1, sigma1, ...]` sequence, or None.
        norm: Applied to every sigma value before it is used.
        sigma: Constant sigma used when `data` is None.

    Returns:
        A function of the location `x`. With `data` None it always returns
        `norm(sigma)`. Otherwise it interpolates linearly between the
        normalized sigmas of the first pair of neighbouring locations
        (sorted ascending) whose upper location is >= `x`.

        The returned function raises ValueError when `x` lies above the
        largest location or fewer than two pairs were given.
    """
    if data is None:
        return lambda _: norm(sigma)

    # sorted() is stable, so pairs sharing a location keep their input order.
    packs = sorted(zip(data[::2], data[1::2]), key=lambda group: group[0])

    def func(x: float) -> float:
        for (loc_lo, sigma_lo), (loc_hi, sigma_hi) in zip(packs, packs[1:]):
            if x <= loc_hi:
                span = loc_hi - loc_lo
                if span == 0:
                    # Two entries at the same location (only reachable for
                    # the lowest pair): take the later one instead of
                    # dividing by zero.
                    return norm(sigma_hi)
                weight = (x - loc_lo) / span
                return (1 - weight) * norm(sigma_lo) + weight * norm(sigma_hi)
        raise ValueError(
            f"location {x} is not covered by the given (location, sigma) pairs"
        )

    return func
def DFTTest(
    clip: vs.VideoNode,
    ftype: typing.Literal[0, 1, 2, 3, 4] = 0,
    sigma: float = 8.0,
    sigma2: float = 8.0,
    pmin: float = 0.0,
    pmax: float = 500.0,
    sbsize: int = 16,
    smode: typing.Literal[0, 1] = 1,
    sosize: int = 12,
    tbsize: int = 3,
    # tmode=0, tosize=0
    swin: typing.Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] = 0,
    twin: typing.Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] = 7,
    sbeta: float = 2.5,
    tbeta: float = 2.5,
    zmean: bool = True,
    f0beta: float = 1.0,
    nlocation: typing.Optional[typing.Sequence[int]] = None,
    alpha: typing.Optional[float] = None,
    slocation: typing.Optional[typing.Union[
        typing.Sequence[typing.Tuple[FREQ, SIGMA]],
        typing.Sequence[float]
    ]] = None,
    ssx: typing.Optional[typing.Union[
        typing.Sequence[typing.Tuple[FREQ, SIGMA]],
        typing.Sequence[float]
    ]] = None,
    ssy: typing.Optional[typing.Union[
        typing.Sequence[typing.Tuple[FREQ, SIGMA]],
        typing.Sequence[float]
    ]] = None,
    sst: typing.Optional[typing.Union[
        typing.Sequence[typing.Tuple[FREQ, SIGMA]],
        typing.Sequence[float]
    ]] = None,
    ssystem: typing.Literal[0, 1] = 0,
    planes: typing.Optional[typing.Union[int, typing.Sequence[int]]] = None,
    backend: typing.Optional[backendT] = None
) -> vs.VideoNode:
    """ 2D/3D frequency domain denoiser

    The interface is compatible with core.dfttest.DFTTest by HolyWu.

    Args:
        clip: Clip to process.
            Any format with either integer sample type of 8-16 bit depth
            or float sample type of 32 bit depth is supported.

        ftype: Controls the filter type.
            Possible settings are:
                0: generalized wiener filter
                    mult = max((psd - sigma) / psd, 0) ^ f0beta
                1: hard threshold
                    mult = psd < sigma ? 0.0 : 1.0
                2: multiplier
                    mult = sigma
                3: multiplier switched based on psd value
                    mult = (psd >= pmin && psd <= pmax) ? sigma : sigma2
                4: multiplier modified based on psd value and range
                    mult = sigma * sqrt((psd * pmax) / ((psd + pmin) * (psd + pmax)))
            The real and imaginary parts of each complex dft coefficient are multiplied
            by the corresponding 'mult' value.
            ** psd = magnitude squared = real*real + imag*imag

        sigma, sigma2: Value of sigma and sigma2.
            If using the slocation parameter then the sigma parameter is ignored.

        pmin, pmax: Used as described in the ftype parameter description.

        sbsize: Sets the length of the sides of the spatial window.
            Must be 1 or greater. Must be odd if using smode=0.

        smode: Sets the mode for spatial operation.
            Currently only smode=1 is implemented.

        sosize: Sets the spatial overlap amount.
            Must be in the range 0 to sbsize-1 (inclusive).
            If sosize is greater than sbsize>>1, then sbsize%(sbsize-sosize) must equal 0.
            In other words, overlap greater than 50% requires that sbsize-sosize be a divisor of sbsize.

        tbsize: Sets the length of the temporal dimension (i.e. number of frames).
            Must be at least 1. Must be odd if using tmode=0.

        tmode: Sets the mode for temporal operation.
            Currently only tmode=0 is implemented.

        tosize: Sets the temporal overlap amount.
            Must be in the range 0 to tbsize-1 (inclusive).
            If tosize is greater than tbsize>>1, then tbsize%(tbsize-tosize) must equal 0.
            In other words, overlap greater than 50% requires that tbsize-tosize be a divisor of tbsize.

        swin, twin: Sets the type of analysis/synthesis window to be used for spatial (swin) and
            temporal (twin) processing. Possible settings:
                0: hanning
                1: hamming
                2: blackman
                3: 4 term blackman-harris
                4: kaiser-bessel
                5: 7 term blackman-harris
                6: flat top
                7: rectangular
                8: Bartlett
                9: Bartlett-Hann
                10: Nuttall
                11: Blackman-Nuttall

        sbeta,tbeta: Sets the beta value for kaiser-bessel window type.
            sbeta goes with swin, tbeta goes with twin.
            Not used unless the corresponding window value is set to 4.

        zmean: Controls whether the window mean is subtracted out (zero'd)
            prior to filtering in the frequency domain.

        f0beta: Power term in ftype=0.

        nlocation: Currently not implemented.

        slocation/ssx/ssy/sst: Used to specify functions of sigma based on frequency.
            Check the original documentation for details.
            Note that in current implementation,
            "slocation = [(0.0, 1.0), (1.0, 10.0)]"
            is equivalent to
            "slocation = [0.0, 1.0, 1.0, 10.0]"
            An empty sequence is treated the same as None.

        ssystem: Method of sigma computation.
            Check the original documentation for details.

        planes: Sets which planes will be processed.
            Any unprocessed planes will be simply copied.

        backend: Backend implementation to use.
            All available backends can be found in the dfttest2.Backend "namespace":
                dfttest2.Backend.{CPU, cuFFT, NVRTC, GCC}

            The CPU, NVRTC and GCC backends require sbsize=16.
            The cuFFT and NVRTC backend require a CUDA-enabled system.

            Speed: NVRTC >> cuFFT > CPU == GCC

    Raises:
        ValueError: If any argument is outside its documented range.
    """

    if (
        not isinstance(clip, vs.VideoNode) or
        clip.width == 0 or
        clip.height == 0 or
        clip.format is None or
        (clip.format.sample_type == vs.INTEGER and clip.format.bits_per_sample > 16) or
        (clip.format.sample_type == vs.FLOAT and clip.format.bits_per_sample != 32)
    ):
        raise ValueError("only constant format 8-16 bit integer and 32 bit float input supported")

    if ftype < 0 or ftype > 4:
        raise ValueError("ftype must be 0, 1, 2, 3, or 4")

    if sbsize < 1:
        raise ValueError("sbsize must be greater than or equal to 1")

    if smode != 1:
        raise ValueError('"smode" must be 1')

    # Enforce the documented range first: sosize == sbsize would otherwise
    # reach `sbsize % 0` below and fail with ZeroDivisionError.
    if sosize < 0 or sosize >= sbsize:
        raise ValueError("sosize must be between 0 and sbsize-1 (inclusive)")

    if sosize > sbsize // 2 and (sbsize % (sbsize - sosize) != 0):
        raise ValueError("spatial overlap greater than 50% requires that sbsize-sosize is a divisor of sbsize")

    if tbsize < 1:
        raise ValueError('"tbsize" must be at least 1')

    if swin < 0 or swin > 11:
        raise ValueError("swin must be between 0 and 11 (inclusive)")

    if twin < 0 or twin > 11:
        raise ValueError("twin must be between 0 and 11 (inclusive)")

    if nlocation is not None:
        raise ValueError('"nlocation" must be None')

    def as_flat(values):  # type: ignore
        # Empty specifications count as "not given"; flatten() is only
        # called on non-empty input.
        return flatten(values) if values else None

    # Flatten before checking the element count, so that the pair form
    # [(f, s), ...] is measured in numbers and not in tuples (an odd number
    # of pairs is perfectly valid).
    flat_slocation = as_flat(slocation)
    flat_ssx = as_flat(ssx)
    flat_ssy = as_flat(ssy)
    flat_sst = as_flat(sst)

    named_specs = (
        ("slocation", flat_slocation),
        ("ssx", flat_ssx),
        ("ssy", flat_ssy),
        ("sst", flat_sst),
    )
    for spec_name, spec in named_specs:
        if spec is not None and len(spec) % 2 != 0:
            raise ValueError(f"number of elements in {spec_name} must be a multiple of 2")

    if ssystem < 0 or ssystem > 1:
        raise ValueError("ssystem must be 0 or 1")

    def norm(x: float) -> float:
        # With ssystem=0 the per-axis sigmas are multiplied together, so
        # each axis gets the matching root (square root for 2-D, cube root
        # for 3-D). With slocation and ssystem=1 a single function is
        # evaluated, so no root is taken.
        if flat_slocation is not None and ssystem == 1:
            return x
        elif tbsize == 1:
            return math.sqrt(x)
        else:
            return x ** (1 / 3)

    _sigma: typing.Union[float, typing.Sequence[typing.Callable[[float], float]]]

    if flat_slocation is not None:
        _sigma = [to_func(flat_slocation, norm, sigma)] * 3
    elif any(ss is not None for ss in (flat_ssx, flat_ssy, flat_sst)):
        _sigma = [to_func(ss, norm, sigma) for ss in (flat_ssx, flat_ssy, flat_sst)]
    else:
        _sigma = sigma

    return DFTTest2(
        clip=clip,
        ftype=ftype,
        sigma=_sigma,
        sigma2=sigma2,
        pmin=pmin,
        pmax=pmax,
        sbsize=sbsize,
        sosize=sosize,
        tbsize=tbsize,
        swin=swin,
        twin=twin,
        sbeta=sbeta,
        tbeta=tbeta,
        zmean=zmean,
        f0beta=f0beta,
        ssystem=ssystem,
        planes=planes,
        backend=select_backend(backend, sbsize, tbsize)
    )