Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/src/juliacall-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ jl.Vector[jl.Int]()
```

Some Julia types can be converted to corresponding numpy dtypes like `numpy.dtype(jl.Int)`.
Supports primitive types: `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`,
`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`. Also
supports tuples, named tuples and structs of these.
Supports `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`,
`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`, plus
`Tuple`s and `NamedTuple`s of these.
`````

`````@customdoc
Expand Down
7 changes: 4 additions & 3 deletions src/JlWrap/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ pybufferformat(::Type{T}) where {T} =
T == Complex{Cdouble} ? "Zd" :
T == Bool ? "?" :
T == Ptr{Cvoid} ? "P" :
if isstructtype(T) && isconcretetype(T) && allocatedinline(T)
if (T <: Union{Tuple,NamedTuple}) && isstructtype(T) && isconcretetype(T) && allocatedinline(T)
n = fieldcount(T)
flds = []
for i = 1:n
Expand Down Expand Up @@ -234,7 +234,7 @@ pyjlarray_isarrayabletype(::Type{NamedTuple{names,types}}) where {names,types} =

const PYTYPESTRDESCR = IdDict{Type,Tuple{String,Py}}()

pytypestrdescr(::Type{T}) where {T} =
function pytypestrdescr(::Type{T}) where {T}
get!(PYTYPESTRDESCR, T) do
c = Utils.islittleendian() ? '<' : '>'
if T == Bool
Expand Down Expand Up @@ -275,7 +275,7 @@ pytypestrdescr(::Type{T}) where {T} =
u == NumpyDates.UNBOUND_UNITS ? "" :
m == 1 ? "[$(Symbol(u))]" : "[$(m)$(Symbol(u))]"
("$(c)$(tc)8$(us)", PyNULL)
elseif isstructtype(T) && isconcretetype(T) && Base.allocatedinline(T)
elseif (T <: Union{Tuple,NamedTuple}) && isstructtype(T) && isconcretetype(T) && Base.allocatedinline(T)
n = fieldcount(T)
flds = []
for i = 1:n
Expand All @@ -298,6 +298,7 @@ pytypestrdescr(::Type{T}) where {T} =
("", PyNULL)
end
end
end

pyjlarray_array__array(x::AbstractArray) = x isa Array ? Py(nothing) : pyjl(Array(x))
pyjlarray_array__pyobjectarray(x::AbstractArray) = pyjl(PyObjectArray(x))
Expand Down
49 changes: 38 additions & 11 deletions src/JlWrap/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,49 @@ function pyjltype_getitem(self::Type, k_)
end
end

const PYNUMPYDTYPE = IdDict{Type,Py}()

function pyjltype_numpy_dtype(self::Type)
typestr, descr = pytypestrdescr(self)
if isempty(typestr)
errset(pybuiltins.AttributeError, "__numpy_dtype__")
return PyNULL
ans = get!(PYNUMPYDTYPE, self) do
typestr, descr = pytypestrdescr(self)
# unsupported type
if typestr == ""
return PyNULL
end
np = pyimport("numpy")
# simple scalar type
if pyisnull(descr)
return np.dtype(typestr)
end
# We could juse use np.dtype(descr), but when there is padding, np.dtype(descr)
# changes the names of the padding fields from "" to "f{N}". Using this other
# dtype constructor avoids this issue and preserves the invariant:
# np.dtype(eltype(array)) == np.array(array).dtype
names = []
formats = []
offsets = []
for i = 1:fieldcount(self)
nm = fieldname(self, i)
push!(names, nm isa Integer ? "f$(nm-1)" : String(nm))
ts, ds = pytypestrdescr(fieldtype(self, i))
push!(formats, pyisnull(ds) ? ts : ds)
push!(offsets, fieldoffset(self, i))
end
return np.dtype(
pydict(
names = pylist(names),
formats = pylist(formats),
offsets = pylist(offsets),
itemsize = sizeof(self),
),
)
end
np = pyimport("numpy")
if pyisnull(descr)
return np.dtype(typestr)
else
return np.dtype(descr)
if pyisnull(ans)
errset(pybuiltins.AttributeError, "__numpy_dtype__")
end
return ans
end

pyjl_handle_error_type(::typeof(pyjltype_numpy_dtype), x, exc) = pybuiltins.AttributeError

function init_type()
jl = pyjuliacallmodule
pybuiltins.exec(
Expand Down
12 changes: 9 additions & 3 deletions test/JlWrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,22 +510,28 @@ end
(Tuple{Int32, Int32}, pylist([("f0", "int32"), ("f1", "int32")])),
(@NamedTuple{}, pylist()),
(@NamedTuple{x::Int32, y::Int32}, pylist([("x", "int32"), ("y", "int32")])),
(Pair{Int32, Int32}, pylist([("first", "int32"), ("second", "int32")])),
]
@test pyeq(Bool, pygetattr(pyjl(t), "__numpy_dtype__"), np.dtype(d))
@test pyeq(Bool, np.dtype(pyjl(t)), np.dtype(d))
@test pyeq(Bool, np.dtype(t), np.dtype(d))
# test the invariant np.dtype(eltype(array)) == np.array(array).dtype
@test isequal(np.dtype(t), np.array(t[]).dtype)
end

# unsupported cases
@testset "$t -> AttributeError" for t in [
# non-primitives or mutables
# structs / mutables
Pair,
Pair{Int,Int},
String,
Vector{Int},
# pointers
Ptr{Cvoid},
Ptr{Int},
# PyPtr specifically should NOT be interpreted as np.dtype("O")
PythonCall.C.PyPtr,
# tuples containing illegal things
Tuple{String},
Tuple{Pair{Int,Int}},
]
err = try
pygetattr(pyjl(t), "__numpy_dtype__")
Expand Down
Loading