Skip to content

Commit 9a2b5a1

Browse files
authored
Merge pull request #342 from malmaud/sess_close
Define `Base.close(::Session)`
2 parents 5d40570 + f25ea1e commit 9a2b5a1

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

src/core.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,7 @@ mutable struct Session
554554
this = new(ptr, graph)
555555
check_status(status)
556556
finalizer(this, self->begin
557-
status = Status()
558-
@tfcall(:TF_DeleteSession, Void, (Ptr{Void}, Ptr{Void}), self.ptr, status.ptr)
557+
close(self)
559558
end)
560559
return this
561560
end
@@ -571,6 +570,21 @@ mutable struct Session
571570
end
572571
end
573572

573+
"""
574+
close(sess::Session)
575+
576+
Closes the TensorFlow session, freeing the associated computational resources.
577+
"""
578+
function Base.close(sess::Session)
579+
if sess.ptr != C_NULL
580+
status = Status()
581+
@tfcall(:TF_DeleteSession, Void, (Ptr{Void}, Ptr{Void}), sess.ptr, status.ptr)
582+
check_status(status)
583+
sess.ptr = C_NULL
584+
end
585+
return nothing
586+
end
587+
574588

575589
mutable struct Buffer
576590
ptr::Ptr{Void}

src/run.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,18 @@ function build_input(tensor_map::Dict)
7777
input_tensors, input_values
7878
end
7979

80+
struct ClosedSessionError <: Exception
81+
end
82+
83+
function Base.show(io::IO, err::ClosedSessionError)
84+
print(io, "An operation was attempted on a closed TensorFlow session.")
85+
end
86+
8087
function run(sess::Session, inputs, input_values, outputs, targets)
8188
#Low level run, without size checking, and type conversion etc.
82-
89+
if sess.ptr == C_NULL
90+
throw(ClosedSessionError())
91+
end
8392
status = Status()
8493
output_values = fill(C_NULL, length(outputs))
8594
input_tensors = [RawTensor(x) for x in input_values]
@@ -184,6 +193,9 @@ end
184193

185194

186195
"""
196+
run(sess::Session, output, input_dict::Dict)
197+
198+
187199
Compute the result of one of more operations in the computation graph.
188200
"""
189201
function run(sess::Session, output, input_dict)

test/core.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ end
2626
end
2727
end
2828

29+
@testset "Session closing" begin
30+
session = tf.Session(Graph())
31+
x = constant(1)
32+
@test run(session, x) == 1
33+
close(session)
34+
close(session) # Test that we can safely call `close` twice on the same session
35+
@test_throws tf.ClosedSessionError run(session, x)
36+
end
37+
2938
@testset "get_operations" begin
3039
let
3140
graph = Graph()

0 commit comments

Comments
 (0)