forked from h2oai/h2o-3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexample.scala
55 lines (47 loc) · 1.9 KB
/
example.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
//
// Load data
//
// Convenience form for files that are available locally
val air = new DataFrame(H2OFiles.get("allyears_tiny.csv"))
// General form for any data source addressed as '<scheme>://location':
//   val dataSourceURI = new java.net.URI("hdfs://mr-0xd6")
//   val air = new DataFrame(dataSourceURI)
//
// Note: data living in the cluster should get an h2o-specific URI scheme
//
// Build a uniformly distributed random vector; the frame is passed in so that
// the vector's length and vector group can be derived from it
val uniformVec = Vec.runif(air)
// Append the vector as column 'S at the end of the frame so M/R tasks can use it
val airS = air ++ ('S, uniformVec)
//
// Filtering and slicing
//
// Note: this is only an idea based on the Scalding + Shalala + h2o-dev-scala APIs
//
// Call shape:
//   Frame . Oper ( ColSelect (and output spec) ) { FUNC }
//
// Dot notation is required for this curried call. Written infix,
// `airS filter ('S) { f }` parses as `airS.filter(('S)(f))`: the block binds
// to the parenthesised symbol instead of becoming filter's second argument list.
//
// Three disjoint ranges of the random column 'S:
//   s <= 0.8 / 0.8 < s <= 0.9 / s > 0.9
val airTrain = airS.filter('S) { (s: Double) => s <= 0.8 }
val airValid = airS.filter('S) { (s: Double) => s > 0.8 && s <= 0.9 }
val airTest = airS.filter('S) { (s: Double) => s > 0.9 }
// Create Parameters for run
val gbmParams = new GBMParameters()
// Column selector: the predictors plus the response column. The response named
// in _response_column below has to be present in the training frame, otherwise
// the builder has nothing to train against.
gbmParams._train = airTrain('Origin, 'Dest, 'Distance, 'UniqueCarrier, 'Month, 'DayofMonth, 'DayOfWeek, 'IsDepDelayed)
gbmParams._valid = airValid // Do not need to select columns since algo will filter right one
gbmParams._response_column = 'IsDepDelayed
gbmParams._distribution = Distributions.MULTINOMIAL // enum
gbmParams._interaction_depth = 3
gbmParams._shrinkage = 0.01 // shrinkage, i.e. the learning rate applied to each tree
gbmParams._importance = true // presumably enables variable importances - confirm against GBMParameters
gbmParams._cv = new CVParams(nfold=3, seed=42) // 3 folds with a fixed seed
// Create builder
val gbm = new GBM(gbmParams)
// Invoke builder and get a model
val gbmModel = gbm.fit
//
// Make a prediction
// - use API call and select the right column with prediction
//
// Spark-based variant (sketch only; it is disabled, so `airRDD` is not defined):
//val rawAirData = sc.textFile(SparkFiles.get("allyears_tiny.csv"), /* # partitions */ 3)
// Produce RDD[Flight], Flight is POJO
//val airRDD /*:RDD[Flight] */ = rawAirData....
//
// Score the held-out test split: it is the only split not handed to the
// builder as training or validation data.
val pred = gbmModel.score(airTest)