diff --git a/LanguageModel.lua b/LanguageModel.lua
index d6248184..4ccf6755 100644
--- a/LanguageModel.lua
+++ b/LanguageModel.lua
@@ -162,6 +162,7 @@ function LM:sample(kwargs)
   local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
   local sample = utils.get_kwarg(kwargs, 'sample', 1)
   local temperature = utils.get_kwarg(kwargs, 'temperature', 1)
+  local stream = utils.get_kwarg(kwargs, 'stream', 0)
 
   local sampled = torch.LongTensor(1, T)
   self:resetStates()
@@ -172,6 +173,9 @@ function LM:sample(kwargs)
       print('Seeding with: "' .. start_text .. '"')
     end
     local x = self:encode_string(start_text):view(1, -1)
+    if stream == 1 then
+      io.write(start_text)
+    end
     local T0 = x:size(2)
     sampled[{{}, {1, T0}}]:copy(x)
     scores = self:forward(x)[{{}, {T0, T0}}]
@@ -196,6 +200,9 @@ function LM:sample(kwargs)
       next_char = torch.multinomial(probs, 1):view(1, 1)
     end
     sampled[{{}, {t, t}}]:copy(next_char)
+    if stream == 1 then
+      io.write(self.idx_to_token[next_char[1][1]])
+    end
     scores = self:forward(next_char)
   end
 
diff --git a/doc/flags.md b/doc/flags.md
index f2652bbf..4e7cbde3 100644
--- a/doc/flags.md
+++ b/doc/flags.md
@@ -57,3 +57,4 @@ The sampling script `sample.lua` accepts the following command-line flags:
 - `-gpu`: The ID of the GPU to use (zero-indexed). Default is 0. Set this to -1 to run in CPU-only mode.
 - `-gpu_backend`: The GPU backend to use; either `cuda` or `opencl`. Default is `cuda`.
 - `-verbose`: By default just the sampled text is printed to the console. Set this to 1 to also print some diagnostic information.
+- `-stream`: By default the sampled text is buffered and printed in one go. Set this to 1 to disable buffering and stream the sampled text one character at a time.
diff --git a/sample.lua b/sample.lua
index 4e6ebae0..e98aa92e 100644
--- a/sample.lua
+++ b/sample.lua
@@ -13,6 +13,7 @@ cmd:option('-temperature', 1)
 cmd:option('-gpu', 0)
 cmd:option('-gpu_backend', 'cuda')
 cmd:option('-verbose', 0)
+cmd:option('-stream', 0)
 
 local opt = cmd:parse(arg)
 
@@ -39,4 +40,6 @@ if opt.verbose == 1 then print(msg) end
 model:evaluate()
 
 local sample = model:sample(opt)
-print(sample)
+if opt.stream == 0 then -- If streaming then sample has already been printed
+  print(sample)
+end
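
Usage sketch: the new flag can be tried from the command line as below; the checkpoint path is an assumption for illustration, while -checkpoint and -length are existing sample.lua options.

    th sample.lua -checkpoint cv/checkpoint.t7 -length 2000 -stream 1

With the default -stream 0, behaviour is unchanged: model:sample returns the full text and print(sample) emits it once at the end. With -stream 1, each character is written as it is sampled and the final print is skipped to avoid emitting the text twice.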