Commit

Merge pull request #491 from malmaud/eager_mode

Eager mode

malmaud committed Mar 26, 2019
2 parents 859d2b4 + 0492959 commit 0635115
Showing 25 changed files with 111,326 additions and 2,559 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -5,8 +5,10 @@ version = "0.12.0"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
4 changes: 2 additions & 2 deletions deps/build.jl
@@ -1,8 +1,8 @@
 using PyCall
 using Conda
 
-const cur_version = "1.12.0"
-const cur_py_version = "1.12.0"
+const cur_version = "1.13.1"
+const cur_py_version = "1.12" # Temporarily downgrade Python version until 1.13.1 is released on Conda
 
 
 ############################
8 changes: 8 additions & 0 deletions examples/diffeq.jl
@@ -0,0 +1,8 @@
+using DifferentialEquations
+
+f(u, p, t) = 1.01 .* u
+
+u0 = constant(0.5)
+tspan = (0.0, 1.0)
+prob = ODEProblem(f, u0, tspan)
+s = solve(prob)
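
As committed, this example calls constant without loading TensorFlow, so it presumably assumes a session where TensorFlow.jl (and, likely, eager execution) is already active. A self-contained sketch under those assumptions:

using TensorFlow
using DifferentialEquations

TensorFlow.enable_eager_execution()  # assumed prerequisite; not part of the committed example

f(u, p, t) = 1.01 .* u               # exponential growth: du/dt = 1.01u
u0 = constant(0.5)                   # initial condition as a TensorFlow tensor
tspan = (0.0, 1.0)
prob = ODEProblem(f, u0, tspan)
s = solve(prob)                      # DifferentialEquations.jl steps the tensor through its solver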
13 changes: 13 additions & 0 deletions examples/keras.jl
@@ -0,0 +1,13 @@
+using TensorFlow
+tf = TensorFlow
+tf.enable_eager_execution()
+m = tf.Sequential()
+
+tf.add(m, tf.Dense(3, 10))
+tf.add(m, tf.Relu())
+tf.add(m, tf.Dense(10, 3))
+
+x = constant(randn(5, 3))
+y = 3x + 5
+tf.compile(m, optimizer=tf.SGD(lr=1e-4), loss=tf.mse)
+tf.fit(m, x, y, n_epochs=200)
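
For intuition (not part of the diff): the script assembles a 3→10→3 multilayer perceptron with a ReLU in between, compiles it with SGD (learning rate 1e-4) and mean-squared-error loss, and fits it for 200 epochs. The regression target it is asked to approximate, written with plain Julia arrays rather than tensors:

x = randn(5, 3)     # 5 samples with 3 features each, matching constant(randn(5, 3))
y = 3 .* x .+ 5     # the elementwise affine map the network is trained to fit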
17 changes: 16 additions & 1 deletion src/TensorFlow.jl
@@ -124,7 +124,14 @@ Ops,
 slice,
 import_op,
 @tfimport,
-tf_versioninfo
+tf_versioninfo,
+copy_to_device,
+enable_eager_execution,
+EagerTensor,
+summary,
+create_tape,
+set_tape,
+with_tape
 
 
 using Distributed
@@ -141,8 +148,13 @@ function deallocator(data, len, arg)

 end
 
+include("context.jl")
+
 function __init__()
     c_deallocator[] = @cfunction(deallocator, Cvoid, (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}))
+    for context in default_context()
+        push!(global_context, context)
+    end
 end
 
 function load_python_process(;force_reload=false)
@@ -198,6 +210,7 @@ include("meta.jl")
include("constants.jl")
include("tensorflow_protos.jl")
include("core.jl")
include("eager.jl")
include("run.jl")
include("version.jl")
include("ops.jl")
@@ -211,5 +224,7 @@ include("summary.jl")
include("deprecated.jl")
include("show.jl")
include("generate_ops.jl")
include("tape.jl")
include("keras.jl")

end
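
Taken together, these changes wire eager mode into the module: context.jl is included before __init__ so that __init__ can seed global_context from default_context(), and eager.jl, tape.jl, and keras.jl are pulled in alongside the existing graph-mode files. A minimal usage sketch, assuming constant produces an EagerTensor once eager execution is enabled (the actual semantics live in eager.jl, which is not shown in this view):

using TensorFlow
enable_eager_execution()       # newly exported; switches the ExecutionMode context to eager
x = constant([1.0, 2.0, 3.0])  # assumed to build an EagerTensor in eager mode
y = x + x                      # assumed to run the op immediately, with no Session or run() call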
44 changes: 44 additions & 0 deletions src/context.jl
@@ -0,0 +1,44 @@
+abstract type Context
+end
+
+struct ContextStack
+    contexts::Vector{Context}
+end
+
+ContextStack() = ContextStack(Context[])
+
+function Base.push!(stack::ContextStack, context::Context)
+    push!(stack.contexts, context)
+end
+
+function Base.pop!(stack::ContextStack)
+    pop!(stack.contexts)
+end
+
+function default_context()
+    return [ExecutionMode(eager=false)]
+end
+
+function context_value(context_type)
+    return global_context[context_type]
+end
+
+function Base.getindex(c::ContextStack, context_type)
+    value = nothing
+    for context in c.contexts
+        if isa(context, context_type)
+            value = context
+        end
+    end
+    return value
+end
+
+function with_context(block, ctx)
+    push!(global_context, ctx)
+    res = block()
+    # This assumes the block doesn't adjust the context stack; we should explicitly pop the context we pushed.
+    pop!(global_context)
+    return res
+end
+
+const global_context = ContextStack()
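
The context stack is the backbone of the eager/graph switch: default_context() seeds it with ExecutionMode(eager=false) (ExecutionMode itself is defined in eager.jl, not shown here), getindex scans the stack and returns the innermost context of the requested type, and with_context scopes a context to a block. A usage sketch with a hypothetical context type, plus a try/finally variant addressing the caveat in the comment above:

struct LogLevel <: Context        # hypothetical context type, for illustration only
    verbose::Bool
end

with_context(LogLevel(true)) do
    context_value(LogLevel).verbose   # true: getindex returns the innermost LogLevel
end

# A variant of with_context that stays balanced even if the block throws.
# It still assumes the block itself leaves the stack as it found it.
function with_context_safe(block, ctx)
    push!(global_context, ctx)
    try
        return block()
    finally
        pop!(global_context)
    end
end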
41 changes: 23 additions & 18 deletions src/core.jl
@@ -507,20 +507,20 @@ end
 mutable struct DeviceList
     ptr::Ptr{Cvoid}
     count::Int
+end
 
-    function DeviceList(s::Session)
-        status = Status()
-        ptr = @tfcall(:TF_SessionListDevices, Ptr{Cvoid},
-            (Ptr{Cvoid}, Ptr{Cvoid}), s, status)
-        check_status(status)
-        count = @tfcall(:TF_DeviceListCount, Cint, (Ptr{Cvoid},),
-            ptr)
-        this = new(ptr, count)
-        finalizer(this) do self
-            close(self)
-        end
-        this
-    end
+function DeviceList(s::Session)
+    status = Status()
+    ptr = @tfcall(:TF_SessionListDevices, Ptr{Cvoid},
+        (Ptr{Cvoid}, Ptr{Cvoid}), s, status)
+    check_status(status)
+    count = @tfcall(:TF_DeviceListCount, Cint, (Ptr{Cvoid},),
+        ptr)
+    this = DeviceList(ptr, count)
+    finalizer(this) do self
+        close(self)
+    end
+    this
 end

struct DeviceInfo
@@ -663,6 +663,8 @@ RawTensor(data::AbstractArray) = RawTensor(collect(data))

 RawTensor(t::RawTensor) = t
 
+Base.unsafe_convert(::Type{Ptr{Cvoid}}, t::RawTensor) = t.ptr
+
 function varint_encode(b::IO, n::Integer)
     while n ≥ 2^7
         write(b, UInt8(0b10000000 | (n & 0b1111111)))
@@ -803,7 +805,7 @@ function Base.sizeof(t::RawTensor)
     @tfcall(:TF_TensorByteSize, Csize_t, (Ptr{Cvoid},), t.ptr) |> Int
 end
 
-function set_device(node_desc, device::String)
+function set_device(node_desc, device)
     @tfcall(:TF_SetDevice, Cvoid,
         (Ptr{Cvoid}, Cstring),
         node_desc.ptr, device)
@@ -1168,7 +1170,10 @@ function load_proto(value::tensorflow.AttrValue)
         load_proto(value.list)
     elseif has_field(value, :_type)
         type_ = value._type
-        proto_type_map[type_]
+        get(proto_type_map, type_) do
+            @warn "Unrecognized type. Defaulting to Float32." type_
+            Float32
+        end
     end
 end

@@ -1218,10 +1223,6 @@ Represents the output of an operation in the computation graph
     value_index::Int
 end
 
-get_graph(t::AbstractTensor) = Tensor(t).op.graph
-
-node_name(t::AbstractTensor) = (node_name(Tensor(t).op), Tensor(t).value_index)
-
 function Tensor(op::Operation, value_index::Int)
     base_tensor = Tensor{Any}(op, value_index)
     Tensor{get_output_type(base_tensor)}(op, value_index)
@@ -1242,6 +1243,10 @@ Base.convert(::Type{Tensor{Any}}, value::Tensor{R}) where {R} = value

 Base.convert(::Type{Tensor{T}}, value) where {T} = convert(Tensor{T}, constant(value))
 
+get_graph(t::AbstractTensor) = Tensor(t).op.graph
+
+node_name(t::AbstractTensor) = (node_name(Tensor(t).op), Tensor(t).value_index)
+
 function operation_output_type(port::Port)
     @tfcall(:TF_OperationOutputType, TF_DataType, (Port,), port)
 end
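Two smaller core.jl changes deserve a note. The new Base.unsafe_convert method lets a RawTensor be passed directly wherever a ccall/@tfcall argument is declared Ptr{Cvoid}, since Julia's FFI applies unsafe_convert to arguments automatically, so call sites no longer need to reach into t.ptr. And load_proto now degrades gracefully on unknown dtype enums using three-argument get with a do-block, which evaluates the block only when the key is missing. A stand-alone illustration of that pattern (proto_types is a toy stand-in, not the real proto_type_map):

proto_types = Dict(1 => Float32, 2 => Float64)   # toy mapping from enum value to Julia type
t = get(proto_types, 99) do
    @warn "Unrecognized type. Defaulting to Float32."
    Float32
end
# t === Float32; the warning fires only because 99 is absent from the dictionary.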