-
Notifications
You must be signed in to change notification settings - Fork 5
/
mean_shift.go
145 lines (120 loc) · 4.59 KB
/
mean_shift.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package mlpack
/*
#cgo CFLAGS: -I./capi -Wall
#cgo LDFLAGS: -L. -lmlpack_go_mean_shift
#include <capi/mean_shift.h>
#include <stdlib.h>
*/
import "C"
import "gonum.org/v1/gonum/mat"
// MeanShiftOptionalParam holds the optional settings accepted by MeanShift.
//
// Obtain an instance through MeanShiftOptions, which fills in the defaults.
// Note that the zero value is not equivalent to those defaults: its
// MaxIterations is 0 rather than 1000.
type MeanShiftOptionalParam struct {
	// ForceConvergence keeps the mean shift algorithm running regardless of
	// MaxIterations until the clusters converge.
	ForceConvergence bool

	// InPlace asks the mlpack program to add a column containing the
	// learned cluster assignments to the input dataset.
	InPlace bool

	// LabelsOnly restricts the output to the cluster labels only.
	LabelsOnly bool

	// MaxIterations is the maximum number of iterations before mean shift
	// terminates (default 1000).
	MaxIterations int

	// Radius is the distance below which one of two centroids is removed.
	// A value of 0 or less means mlpack estimates the radius (default 0).
	Radius float64

	// Verbose displays informational messages and the full list of
	// parameters and timers at the end of execution.
	Verbose bool
}
// MeanShiftOptions returns a freshly allocated MeanShiftOptionalParam holding
// the default settings for MeanShift: MaxIterations is 1000, Radius is 0 and
// every boolean flag is false.
func MeanShiftOptions() *MeanShiftOptionalParam {
	// Every default except MaxIterations is already the field's zero value.
	defaults := new(MeanShiftOptionalParam)
	defaults.MaxIterations = 1000
	return defaults
}
// MeanShift performs mean shift clustering on the given dataset and returns
// the learned cluster centroids together with the cluster assignments.
//
// input is the dataset to perform clustering on. param holds the optional
// settings; pass the value returned by MeanShiftOptions (optionally
// modified), or nil to use those defaults.
//
// The search radius is controlled by param.Radius, and the maximum number of
// iterations before the algorithm terminates by param.MaxIterations. An
// optional setting is forwarded to mlpack only when it differs from the
// default set by MeanShiftOptions, so setting a field explicitly to its
// default value is the same as leaving it unset.
//
// For example, to run mean shift clustering on the dataset data and keep the
// cluster centroids:
//
//	// Initialize optional parameters for MeanShift().
//	param := mlpack.MeanShiftOptions()
//	centroids, _ := mlpack.MeanShift(data, param)
//
// Optional parameters:
//   - ForceConvergence (bool): If true, the mean shift algorithm keeps
//     running regardless of MaxIterations until the clusters converge.
//   - InPlace (bool): If true, the mlpack program adds a column containing
//     the learned cluster assignments to the input dataset. NOTE(review):
//     the program documents this in terms of its input/output files;
//     presumably output then holds the labeled dataset — confirm.
//   - LabelsOnly (bool): If true, only the cluster labels are written to
//     output.
//   - MaxIterations (int): Maximum number of iterations before mean shift
//     terminates. Default value 1000.
//   - Radius (float64): If the distance between two centroids is less than
//     the given radius, one will be removed. A radius of 0 or less means an
//     estimate will be calculated and used for the radius. Default value 0.
//   - Verbose (bool): Display informational messages and the full list of
//     parameters and timers at the end of execution.
//
// Return values, in order:
//   - centroid (mat.Dense): The centroids of each cluster.
//   - output (mat.Dense): The output labels or labeled data.
func MeanShift(input *mat.Dense, param *MeanShiftOptionalParam) (*mat.Dense, *mat.Dense) {
	// Treat a nil options struct as "use the defaults" instead of panicking
	// on the field accesses below.
	if param == nil {
		param = MeanShiftOptions()
	}

	params := getParams("mean_shift")
	timers := getTimers()
	// Deferred calls run last-in-first-out, so params are cleaned before
	// timers (the original release order). Deferring also releases them if
	// converting an output panics.
	defer cleanTimers(timers)
	defer cleanParams(params)

	// Start quiet; verbose output is re-enabled below only when requested.
	disableBacktrace()
	disableVerbose()

	// The input dataset is required and is always passed.
	gonumToArmaMat(params, "input", input, false)
	setPassed(params, "input")

	// Optional settings are forwarded only when they differ from the
	// defaults set by MeanShiftOptions.
	if param.ForceConvergence {
		setParamBool(params, "force_convergence", param.ForceConvergence)
		setPassed(params, "force_convergence")
	}
	if param.InPlace {
		setParamBool(params, "in_place", param.InPlace)
		setPassed(params, "in_place")
	}
	if param.LabelsOnly {
		setParamBool(params, "labels_only", param.LabelsOnly)
		setPassed(params, "labels_only")
	}
	if param.MaxIterations != 1000 {
		setParamInt(params, "max_iterations", param.MaxIterations)
		setPassed(params, "max_iterations")
	}
	if param.Radius != 0 {
		setParamDouble(params, "radius", param.Radius)
		setPassed(params, "radius")
	}
	if param.Verbose {
		setParamBool(params, "verbose", param.Verbose)
		setPassed(params, "verbose")
		enableVerbose()
	}

	// Request both outputs from the mlpack program.
	setPassed(params, "centroid")
	setPassed(params, "output")

	// Call the mlpack program.
	C.mlpackMeanShift(params.mem, timers.mem)

	// Convert the results back to gonum matrices before the deferred
	// cleanup releases the mlpack-side parameters.
	var centroidPtr mlpackArma
	centroid := centroidPtr.armaToGonumMat(params, "centroid")
	var outputPtr mlpackArma
	output := outputPtr.armaToGonumMat(params, "output")

	return centroid, output
}