Loading and preparing the PTB dataset
First, import the module and load the data as follows:
from datasetslib.ptb import PTBSimple
ptb = PTBSimple()
# downloads data, converts words to ids, converts files to a list of ids
ptb.load_data()
print('Train :',ptb.part['train'][0:5])
print('Test: ',ptb.part['test'][0:5])
print('Valid: ',ptb.part['valid'][0:5])
print('Vocabulary Length = ',ptb.vocab_len)
The first five elements of each dataset part, along with the vocabulary length, are printed as follows:
Train : [9970, 9971, 9972, 9974, 9975]
Test: [102, 14, 24, 32, 752]
Valid: [1132, 93, 358, 5, 329]
Vocabulary Length = 10000
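The comment in the loading code mentions that words are converted to ids. How `datasetslib` does this internally is not shown here, so the following is only a minimal sketch of one common approach, assuming ids are assigned by descending word frequency. The function name `build_vocab` and the toy sentence are hypothetical and are not part of the library.

```python
from collections import Counter


def build_vocab(words):
    """Map each distinct word to an integer id, most frequent word first.

    Hypothetical helper for illustration; not the datasetslib implementation.
    """
    counts = Counter(words)
    # most_common() sorts by descending count, so the most frequent word gets id 0.
    word2id = {word: idx for idx, (word, _) in enumerate(counts.most_common())}
    # Reverse mapping, analogous in spirit to ptb.id2word used in this section.
    id2word = {idx: word for word, idx in word2id.items()}
    return word2id, id2word


# Toy corpus standing in for the PTB text (an assumption for this sketch).
words = 'the cat sat on the mat the end'.split()
word2id, id2word = build_vocab(words)

# Convert the corpus to a list of ids, as load_data() is described as doing.
ids = [word2id[w] for w in words]
print('Ids:', ids)
print('Vocabulary Length =', len(word2id))
```

With a frequency-based mapping like this, rare words end up with large ids, which would be consistent with the training data above starting at ids close to the vocabulary size.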
We set the context window to two words on each side of the target word and get the CBOW pairs:
ptb.skip_window=2
ptb.reset_index_in_epoch()
# in CBOW input is the context word and output is the target word
y_batch, x_batch = ptb.next_batch_cbow()
print('The CBOW pairs : context,target')
for i in range(5 * ptb.skip_window):
    print('(', [ptb.id2word[x_i] for x_i in x_batch[i]],
          ',', y_batch[i], ptb.id2word[y_batch[i]], ')')
The output is:
The CBOW pairs : context,target
( ['aer', 'banknote', 'calloway', 'centrust'] , 9972 berlitz )
( ['banknote', 'berlitz', 'centrust', 'cluett'] , 9974 calloway )
( ['berlitz', 'calloway', 'cluett', 'fromstein'] , 9975 centrust )
( ['calloway', 'centrust', 'fromstein', 'gitano'] , 9976 cluett )
( ['centrust', 'cluett', 'gitano', 'guterman'] , 9980 fromstein )
( ['cluett', 'fromstein', 'guterman', 'hydro-quebec'] , 9981 gitano )
( ['fromstein', 'gitano', 'hydro-quebec', 'ipo'] , 9982 guterman )
( ['gitano', 'guterman', 'ipo', 'kia'] , 9983 hydro-quebec )
( ['guterman', 'hydro-quebec', 'kia', 'memotec'] , 9984 ipo )
( ['hydro-quebec', 'ipo', 'memotec', 'mlx'] , 9986 kia )
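To make the structure of these pairs concrete, here is a minimal plain-Python sketch of how CBOW pairs could be built from a list of word ids. The function `cbow_pairs` is a hypothetical helper written for illustration; it is not the `datasetslib` implementation, which may differ in batching and ordering. The ids are the ones that appear in the output above.

```python
def cbow_pairs(ids, window):
    """Return (context, target) pairs for every position with a full window.

    Hypothetical helper for illustration only.
    """
    pairs = []
    # Skip the first and last `window` positions, which lack a full context.
    for i in range(window, len(ids) - window):
        # Context = `window` ids to the left plus `window` ids to the right.
        context = ids[i - window:i] + ids[i + 1:i + window + 1]
        pairs.append((context, ids[i]))
    return pairs


# First seven training ids, taken from the printed output above.
ids = [9970, 9971, 9972, 9974, 9975, 9976, 9980]

for context, target in cbow_pairs(ids, window=2):
    print(context, '->', target)
```

Each pair has `2 * window` context ids and one target id. The first pair this sketch produces has context `[9970, 9971, 9974, 9975]` and target `9972`, which corresponds to the first row of the output above (`berlitz` surrounded by `aer`, `banknote`, `calloway`, `centrust`).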
Now let's look at the skip-gram pairs:
ptb.skip_window=2
ptb.reset_index_in_epoch()
# in skip-gram input is the target word and output is the context word
x_batch, y_batch = ptb.next_batch()
print('The skip-gram pairs : target,context')
for i in range(5 * ptb.skip_window):
    print('(', x_batch[i], ptb.id2word[x_batch[i]],
          ',', y_batch[i], ptb.id2word[y_batch[i]], ')')
The output is:
The skip-gram pairs : target,context
( 9972 berlitz , 9970 aer )
( 9972 berlitz , 9971 banknote )
( 9972 berlitz , 9974 calloway )
( 9972 berlitz , 9975 centrust )
( 9974 calloway , 9971 banknote )
( 9974 calloway , 9972 berlitz )
( 9974 calloway , 9975 centrust )
( 9974 calloway , 9976 cluett )
( 9975 centrust , 9972 berlitz )
( 9975 centrust , 9974 calloway )
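The skip-gram pairs can be sketched in the same way. In this minimal plain-Python example, the hypothetical helper `skipgram_pairs` emits one (target, context) pair for each word inside the window, rather than one pair per target as in CBOW. It is an illustration under these assumptions and not the `datasetslib` implementation, which may sample or order context words differently.

```python
def skipgram_pairs(ids, window):
    """Return (target, context) pairs for every position with a full window.

    Hypothetical helper for illustration only.
    """
    pairs = []
    # Skip the first and last `window` positions, which lack a full context.
    for i in range(window, len(ids) - window):
        # Visit every neighbour within `window` positions, left to right.
        for j in range(i - window, i + window + 1):
            if j != i:  # the target is not its own context word
                pairs.append((ids[i], ids[j]))
    return pairs


# First six training ids, taken from the printed output above.
ids = [9970, 9971, 9972, 9974, 9975, 9976]

for target, context in skipgram_pairs(ids, window=2):
    print('(', target, ',', context, ')')
```

For a window of two, each target yields four pairs, so the first target `9972` is paired with `9970`, `9971`, `9974`, and `9975`, in line with the first four rows of the output above.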