Skip to content

Commit

Permalink
v1.2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
jimichan committed Sep 11, 2018
1 parent 7c3dd8c commit ba942a4
Show file tree
Hide file tree
Showing 14 changed files with 660 additions and 583 deletions.
7 changes: 4 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ plugins {


description = 'FastText的java版本实现,兼容facebook发布的原生预训练模型。'
version = "1.1.6-SNAPSHOT"
version = "1.2.0"
//.BUILD-SNAPSHOT

group = "com.mayabot"
Expand All @@ -33,7 +33,7 @@ dependencies {
compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8"

compile group: 'com.carrotsearch', name: 'hppc', version: '0.7.3'
compile 'com.mayabot:maya-simple-blas:1.0.0'
compile 'com.mayabot:maya-simple-blas:1.1.0'
compile group: 'com.google.guava', name: 'guava', version: "19.0"

testCompile 'junit:junit:4.12'
Expand Down Expand Up @@ -102,7 +102,7 @@ publishing {
}

repositories {
if(oss_user){
if(project.hasProperty("oss_user")){
maven {
name 'OssPublic'
if (project.version.endsWith('-SNAPSHOT')) {
Expand All @@ -117,6 +117,7 @@ publishing {
}
}


if(project.hasProperty("maya_pri_user")){
maven {
name 'MayaPrivate'
Expand Down
26 changes: 23 additions & 3 deletions src/example/java/com/mayabot/mynlp/fasttext/AgnewsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,32 @@ import java.io.File
fun main(args: Array<String>) {
val file = File("src/example/resources/ag.train")

val train = FastText.train(file, ModelName.sup)
val model = FastText.train(file, ModelName.sup)

train.saveModel("data/fasttext/javamodel")

AgnewsTest.predict(train)

model.saveModel("data/fasttext/javamodel")

AgnewsTest.predict(model)


println("----------------")
//load model
val loadedModel = FastText.loadModel("data/fasttext/javamodel",false)
AgnewsTest.predict(loadedModel)
//
println("----------量化")

val qfst = FastText.quantize(model)

AgnewsTest.predict(qfst)

qfst.saveModel("data/fasttext/javamodel.qu")

println("-------------")

val loadedModel2 = FastText.loadModel("data/fasttext/javamodel.qu",false)
AgnewsTest.predict(loadedModel2)
}

object AgnewsTest{
Expand Down
52 changes: 3 additions & 49 deletions src/main/java/com/mayabot/mynlp/fasttext/Args.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class Args {
var verbose = 2
var lr = 0.05

var qout: Boolean = false

@Throws(IOException::class)
fun save(ofs: FileChannel) {

Expand Down Expand Up @@ -130,12 +128,12 @@ class TrainArgs {
/**
* size of word vectors [100]
*/
var dim: Int? = null
var dim: Int? = null

/**
* size of the context window [5]
*/
var ws: Int? = null
var ws: Int? = null

/**
* number of epochs [5]
Expand All @@ -162,50 +160,6 @@ class TrainArgs {
*/
var pretrainedVectors: String = ""

fun setLr(lr: Double?): TrainArgs {
this.lr = lr
return this
}

fun setLrUpdateRate(lrUpdateRate: Int?): TrainArgs {
this.lrUpdateRate = lrUpdateRate
return this
}

fun setDim(dim: Int?): TrainArgs {
this.dim = dim
return this
}

fun setWs(ws: Int?): TrainArgs {
this.ws = ws
return this
}

fun setEpoch(epoch: Int?): TrainArgs {
this.epoch = epoch
return this
}

fun setNeg(neg: Int?): TrainArgs {
this.neg = neg
return this
}

fun setLoss(loss: LossName): TrainArgs {
this.loss = loss
return this
}

fun setThread(thread: Int?): TrainArgs {
this.thread = thread
return this
}

fun setPretrainedVectors(pretrainedVectors: String): TrainArgs {
this.pretrainedVectors = pretrainedVectors
return this
}
}


Expand All @@ -230,7 +184,7 @@ enum class LossName private constructor(var value: Int) {
}


enum class ModelName private constructor(var value: Int) {
enum class ModelName constructor(var value: Int) {

/**
* CBOW
Expand Down
34 changes: 27 additions & 7 deletions src/main/java/com/mayabot/mynlp/fasttext/Dictionary.kt
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class Dictionary(private val args: Args) {
* @throws Exception
*/
@Throws(Exception::class)
fun buildFromFile(file: File) {
fun buildFromFile(file: TrainExampleSource) {

val mmm = 0.75 * MAX_VOCAB_SIZE

Expand All @@ -186,10 +186,10 @@ class Dictionary(private val args: Args) {
val splitter = Splitter.on(CharMatcher.whitespace())
.omitEmptyStrings().trimResults()

file.useLines { lines ->
lines.filterNot { it.isNullOrBlank() || it.startsWith("#") }
.forEach { line ->
splitter.split(line).forEach { token ->
val lines = file.iteratorAll()
lines.use {
it.forEach { line->
line.forEach { token->
add(token)
if (ntokens % 1000000 == 0L && args.verbose > 1) {
print("\rRead " + ntokens / 1000000 + "M words")
Expand All @@ -201,8 +201,7 @@ class Dictionary(private val args: Args) {
}
}
add(EOS)
}

}
threshold(args.minCount.toLong(), args.minCountLabel.toLong())

initTableDiscard()
Expand All @@ -220,6 +219,27 @@ class Dictionary(private val args: Args) {
}
}

//
// file.useLines { lines ->
// lines.filterNot { it.isNullOrBlank() || it.startsWith("#") }
// .forEach { line ->
// splitter.split(line).forEach { token ->
// add(token)
// if (ntokens % 1000000 == 0L && args.verbose > 1) {
// print("\rRead " + ntokens / 1000000 + "M words")
// }
//
// if (size > mmm) {
// minThreshold++
// threshold(minThreshold, minThreshold)
// }
// }
// add(EOS)
// }
//
//
// }


}

Expand Down
7 changes: 7 additions & 0 deletions src/main/java/com/mayabot/mynlp/fasttext/ExampleIterator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.mayabot.mynlp.fasttext;

import java.util.Iterator;
import java.util.List;

public interface ExampleIterator extends AutoCloseable, Iterator<List<String>> {
}
Loading

0 comments on commit ba942a4

Please sign in to comment.