ctakes-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dlig...@apache.org
Subject svn commit: r1770222 - in /ctakes/trunk/ctakes-temporal/scripts/nn: lstm_classify_hybrid.py lstm_train_hybrid.py
Date Thu, 17 Nov 2016 15:44:58 GMT
Author: dligach
Date: Thu Nov 17 15:44:58 2016
New Revision: 1770222

URL: http://svn.apache.org/viewvc?rev=1770222&view=rev
Log:
lstm pos+token model

Added:
    ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify_hybrid.py
    ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train_hybrid.py

Added: ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify_hybrid.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify_hybrid.py?rev=1770222&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify_hybrid.py (added)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify_hybrid.py Thu Nov 17 15:44:58 2016
@@ -0,0 +1,81 @@
+#!python
+
+from keras.models import Sequential, model_from_json
+import numpy as np
+import et_cleartk_io as ctk_io
+import sys
+import os.path
+import pickle
+from keras.preprocessing.sequence import pad_sequences
+
def main(args):
    """Classify event-event relation instances read from stdin.

    Reads lines of the form '<tokens>|<pos tags>' from stdin and prints the
    predicted relation label for each, using the token+POS hybrid LSTM model
    produced by lstm_train_hybrid.py. Exits on an empty line or Ctrl-C.

    args: command-line arguments; args[0] is the working directory
          (NOTE(review): currently unused below — model artifacts are loaded
          from a hard-coded path instead; confirm this is intentional).
    """
    if len(args) < 1:
        sys.stderr.write("Error - one required argument: <model directory>\n")
        sys.exit(-1)
    working_dir = args[0]

    # NOTE(review): the model is loaded from a fixed eval directory under
    # $CTAKES_ROOT rather than from working_dir — presumably a deliberate
    # shortcut for the event-event experiment; verify against the caller.
    target_dir = 'ctakes-temporal/target/eval/thyme/train_and_test/event-event/'
    model_dir = os.path.join(os.environ['CTAKES_ROOT'], target_dir)

    # Alphabets and model artifacts written by lstm_train_hybrid.py.
    maxlen   = pickle.load(open(os.path.join(model_dir, "maxlen.p"), "rb"))
    word2int = pickle.load(open(os.path.join(model_dir, "word2int.p"), "rb"))
    tag2int = pickle.load(open(os.path.join(model_dir, "tag2int.p"), "rb"))
    label2int = pickle.load(open(os.path.join(model_dir, "label2int.p"), "rb"))
    model = model_from_json(open(os.path.join(model_dir, "model_0.json")).read())
    model.load_weights(os.path.join(model_dir, "model_0.h5"))

    # Invert the label alphabet so predicted indices map back to label strings.
    int2label = dict((integer, label) for label, integer in label2int.items())

    while True:
        try:
            line = sys.stdin.readline().rstrip()
            if not line:
                break

            text, pos = line.strip().split('|')

            # Map tokens/tags to integers, falling back to the OOV entries
            # ('none' / 'oov_tag') used at training time.
            tokens = [word2int.get(token, word2int['none'])
                      for token in text.rstrip().split()]
            tags = [tag2int.get(tag, tag2int['oov_tag'])
                    for tag in pos.rstrip().split()]

            # Truncate to the training-time maximum sequence length; shorter
            # sequences are left-padded by pad_sequences below.
            tokens = tokens[:maxlen]
            tags = tags[:maxlen]

            test_x1 = pad_sequences([tokens], maxlen=maxlen)
            test_x2 = pad_sequences([tags], maxlen=maxlen)

            # BUG FIX: the original appended test_x1 twice, so the POS branch
            # of the merged model was fed the token sequence instead of tags.
            test_xs = [test_x1, test_x2]

            out = model.predict(test_xs, batch_size=50)[0]

        except KeyboardInterrupt:
            sys.stderr.write("Caught keyboard interrupt\n")
            break

        if line == '':
            sys.stderr.write("Encountered empty string so exiting\n")
            break

        out_str = int2label[out.argmax()]
        print(out_str)
        sys.stdout.flush()

    sys.exit(0)

if __name__ == "__main__":
    main(sys.argv[1:])

Added: ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train_hybrid.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train_hybrid.py?rev=1770222&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train_hybrid.py (added)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train_hybrid.py Thu Nov 17 15:44:58 2016
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+
+import sklearn as sk
+import numpy as np
+np.random.seed(1337)
+import et_cleartk_io as ctk_io
+import nn_models
+import sys
+import os.path
+import dataset_hybrid
+import keras as k
+from keras.utils.np_utils import to_categorical
+from keras.optimizers import RMSprop
+from keras.preprocessing.sequence import pad_sequences
+from keras.models import Sequential
+from keras.layers import Merge
+from keras.layers.core import Dense, Activation
+from keras.layers import LSTM
+from keras.layers.embeddings import Embedding
+import pickle
+
+def main(args):
+    if len(args) < 1:
+        sys.stderr.write("Error - one required argument: <data directory>\n")
+        sys.exit(-1)
+    working_dir = args[0]
+    data_file = os.path.join(working_dir, 'training-data.liblinear')
+
+    # learn alphabet from training data
+    provider = dataset_hybrid.DatasetProvider(data_file)
+    # now load training examples and labels
+    train_x1, train_x2, train_y = provider.load(data_file)
+    # turn x and y into numpy array among other things
+    maxlen = max([len(seq) for seq in train_x1])
+    classes = len(set(train_y))
+
+    train_x1 = pad_sequences(train_x1, maxlen=maxlen)
+    train_x2 = pad_sequences(train_x2, maxlen=maxlen)
+    train_y = to_categorical(np.array(train_y), classes)
+
+    pickle.dump(maxlen, open(os.path.join(working_dir, 'maxlen.p'),"wb"))
+    pickle.dump(provider.word2int, open(os.path.join(working_dir, 'word2int.p'),"wb"))
+    pickle.dump(provider.tag2int, open(os.path.join(working_dir, 'tag2int.p'),"wb"))
+    pickle.dump(provider.label2int, open(os.path.join(working_dir, 'label2int.p'),"wb"))
+
+    print 'train_x1 shape:', train_x1.shape
+    print 'train_x2 shape:', train_x2.shape
+    print 'train_y shape:', train_y.shape
+
+    branches = [] # models to be merged
+    train_xs = [] # train x for each branch
+
+    branch1 = Sequential()
+    branch1.add(Embedding(len(provider.word2int),
+                          300,
+                          input_length=maxlen,
+                          dropout=0.25))
+    branch1.add(LSTM(128,
+                dropout_W = 0.20,
+                dropout_U = 0.20))
+
+    branches.append(branch1)
+    train_xs.append(train_x1)
+
+    branch2 = Sequential()
+    branch2.add(Embedding(len(provider.tag2int),
+                          300,
+                          input_length=maxlen,
+                          dropout=0.25))
+    branch2.add(LSTM(128,
+                dropout_W = 0.20,
+                dropout_U = 0.20))
+
+    branches.append(branch2)
+    train_xs.append(train_x2)
+
+    model = Sequential()
+    model.add(Merge(branches, mode='concat'))
+    model.add(Dense(classes))
+    model.add(Activation('softmax'))
+
+    optimizer = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08)
+    model.compile(loss='categorical_crossentropy',
+                  optimizer=optimizer,
+                  metrics=['accuracy'])
+    model.fit(train_xs,
+              train_y,
+              nb_epoch=3,
+              batch_size=50,
+              verbose=0,
+              validation_split=0.1)
+
+    json_string = model.to_json()
+    open(os.path.join(working_dir, 'model_0.json'), 'w').write(json_string)
+    model.save_weights(os.path.join(working_dir, 'model_0.h5'), overwrite=True)
+    sys.exit(0)
+
+if __name__ == "__main__":
+    main(sys.argv[1:])



Mime
View raw message