-
Notifications
You must be signed in to change notification settings - Fork 43
/
model.py
41 lines (33 loc) · 1.73 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time : 2019/12/3 7:15 PM
# @Author: wuchenglong
import tensorflow as tf
import tensorflow_addons as tf_ad
class NerModel(tf.keras.Model):
    """BiLSTM-CRF sequence-labelling model.

    Token ids are embedded, passed through dropout (rate 0.5) and a
    bidirectional LSTM, then projected by a dense layer to one score per
    label for every position (``logits``). A trainable
    ``(label_size, label_size)`` matrix holds the CRF transition scores used
    by ``tensorflow_addons.text.crf_log_likelihood``.

    Args:
        hidden_num: Number of units in the LSTM (per direction).
        vocab_size: Vocabulary size, i.e. the input dimension of the embedding.
        label_size: Number of output labels (tags).
        embedding_size: Dimension of the token embeddings.
    """

    def __init__(self, hidden_num, vocab_size, label_size, embedding_size):
        super(NerModel, self).__init__()
        self.num_hidden = hidden_num
        self.vocab_size = vocab_size
        self.label_size = label_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.biLSTM = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(hidden_num, return_sequences=True)
        )
        self.dense = tf.keras.layers.Dense(label_size)
        # Trainable CRF transition matrix, randomly initialised; passed to
        # crf_log_likelihood on every labelled call.
        self.transition_params = tf.Variable(
            tf.random.uniform(shape=(label_size, label_size))
        )
        self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, text, labels=None, training=None):
        """Compute per-token label scores and, optionally, the CRF log-likelihood.

        Args:
            text: Integer token ids. Id 0 is treated as padding when computing
                sequence lengths. NOTE(review): presumably shaped
                (batch, seq_len) - confirm against the input pipeline.
            labels: Optional gold label ids aligned with ``text``. Any integer
                dtype (or a Python/numpy sequence) is accepted; it is cast to
                int32. When given, the CRF log-likelihood is also returned.
            training: Forwarded to the dropout layer (dropout is only active
                in training mode).

        Returns:
            ``(logits, text_lens)`` when ``labels`` is None, otherwise
            ``(logits, text_lens, log_likelihood)``. ``text_lens`` counts the
            non-zero ids along the last axis of ``text``; ``log_likelihood``
            is the value returned by ``tf_ad.text.crf_log_likelihood``.
        """
        # Length of each sequence = number of non-padding (non-zero) ids.
        # The padding id is assumed to be 0 (an earlier author note reads
        # "-1 change 0", presumably padding used to be -1).
        text_lens = tf.math.reduce_sum(
            tf.cast(tf.math.not_equal(text, 0), dtype=tf.int32), axis=-1
        )
        inputs = self.embedding(text)
        inputs = self.dropout(inputs, training)
        logits = self.dense(self.biLSTM(inputs))

        if labels is None:
            return logits, text_lens

        # tf.cast (rather than convert_to_tensor(..., dtype=tf.int32)) also
        # accepts int64 label tensors such as those produced by tf.data.
        label_sequences = tf.cast(labels, dtype=tf.int32)
        # The transition matrix returned by crf_log_likelihood is discarded on
        # purpose: it may be a cast copy, and rebinding self.transition_params
        # to it would replace the trainable Variable with a plain tensor.
        log_likelihood, _ = tf_ad.text.crf_log_likelihood(
            logits,
            label_sequences,
            text_lens,
            transition_params=self.transition_params,
        )
        return logits, text_lens, log_likelihood