-
Notifications
You must be signed in to change notification settings - Fork 267
fix(gh-2036): MyPy Errors in numpyro.distributions.transforms
Module
#2066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
This looks great, thanks! There are some minor errors numpyro/distributions/transforms.py:82: error: Missing positional argument "x" in call to "__call__" of "TransformT" [call-arg]
Installing missing stub packages:
numpyro/distributions/transforms.py:84: error: Name "inv" already defined on line 80 [no-redef]
numpyro/distributions/transforms.py:85: error: Incompatible types in assignment (expression has type "ReferenceType[None]", variable has type "TransformT | None") [assignment]
numpyro/distributions/transforms.py:86: error: Incompatible return value type (got "Array | Any | None", expected "TransformT") [return-value]
numpyro/distributions/transforms.py:1550: error: Item "ndarray[tuple[Any, ...], dtype[Any]]" of "ndarray[tuple[Any, ...], dtype[Any]] | Array" has no attribute "at" [union-attr] Do you need some help with these :) ? |
These errors are from numpyro/distributions/transforms.py#L77-L86, and I am not able to understand the significance of different conditions and the weak reference. I think you can look into this matter. I will take this one,
Thank you for offering help ❤️. |
ok! Sounds like a plan! I will try to look at it in the next days :) |
Hey @Qazalbash I gave it a try as in d57d9a6 . MyPy is happy now, maybe you can try it ? The only key change was |
@juanitorduz Thanks for the changes, Do I need to remove the plugin from |
It's depreciated so it's safe to remove |
…ng of inverse transforms Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
ok! I think the tests are failing because a new NNX release and changes in The other tests |
Here is a patch for the first errors #2067 |
Here's another patch #2069 😸 |
x: Union[jax.Array, Any], | ||
y: Union[jax.Array, Any], | ||
intermediates: Optional[Any] = None, | ||
) -> Union[jax.Array, Any]: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why we need those Union, shouldn't Any be enough?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. You want to see some Array there. How about using ArrayLike?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we change it to ArrayLike
, then we get these errors.
numpyro/distributions/transforms.py:181: error: Incompatible return value type (got "Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any", expected "ndarray[tuple[Any, ...], dtype[Any]] | Array") [return-value]
numpyro/distributions/transforms.py:190: error: Unsupported operand type for unary - ("Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any") [operator]
numpyro/distributions/transforms.py:190: error: Incompatible return value type (got "Array | ndarray[tuple[Any, ...], dtype[Any]] | Any | number[Any, int | float | complex] | int | float | complex", expected "ndarray[tuple[Any, ...], dtype[Any]] | Array") [return-value]
numpyro/distributions/transforms.py:366: error: Incompatible types in assignment (expression has type "Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any", variable has type "ndarray[tuple[Any, ...], dtype[Any]] | Array") [assignment]
numpyro/distributions/transforms.py:371: error: Incompatible types in assignment (expression has type "Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any", variable has type "ndarray[tuple[Any, ...], dtype[Any]] | Array") [assignment]
numpyro/distributions/transforms.py:397: error: Incompatible types in assignment (expression has type "Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any", variable has type "ndarray[tuple[Any, ...], dtype[Any]] | Array") [assignment]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha, thanks! How about using NumLike like in my comment above?
@@ -20,6 +22,9 @@ | |||
Message: TypeAlias = dict[str, Any] | |||
TraceT: TypeAlias = OrderedDict[str, Message] | |||
|
|||
# ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars | |||
StrictArrayT = Union[np.ndarray, jax.Array] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure why StrictArrayT is needed, could you elaborate? I think we can use ArrayLike instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some functions do not operate on boolean
or complex
; ArrayLike
contains boolean
and complex
types, and mypy was throwing an error over it. That's why I created StrictArrayT
, it is used where only arrays are required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But we also want to support int, float at those places. Could you let me know which function has the issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about using ArrayLike in most places and in specific cases, use NumLike? StrictArray can be used at those places which expect non-scalar shapes.
StrictArray = Union[np.ndarray, jax.Array]
NumLike = Union[
jax.Array,
np.ndarray, np.number,
bool, int, float, complex,
]
(we don't need to have T trailing in those cases I think)
…lasses and add PyTree type alias
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @Qazalbash, I think we can get around issues by using NumLike just at some specific places. It's fine to use StrictArray at those arrays with dim >= 1. NonScalarArray is a good name for it I guess.
@@ -20,6 +22,9 @@ | |||
Message: TypeAlias = dict[str, Any] | |||
TraceT: TypeAlias = OrderedDict[str, Message] | |||
|
|||
# ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars | |||
StrictArrayT = Union[np.ndarray, jax.Array] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about using ArrayLike in most places and in specific cases, use NumLike? StrictArray can be used at those places which expect non-scalar shapes.
StrictArray = Union[np.ndarray, jax.Array]
NumLike = Union[
jax.Array,
np.ndarray, np.number,
bool, int, float, complex,
]
(we don't need to have T trailing in those cases I think)
def __call__(self, x: ArrayLike) -> ArrayLike: ... | ||
def _inverse(self, y: ArrayLike) -> ArrayLike: ... | ||
def __call__(self, x: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ... | ||
def _inverse(self, y: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prefer: ArrayLike and not use Any in general
x: Union[jax.Array, Any], | ||
y: Union[jax.Array, Any], | ||
intermediates: Optional[Any] = None, | ||
) -> Union[jax.Array, Any]: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha, thanks! How about using NumLike like in my comment above?
inv = self._inv() | ||
if inv is None: | ||
inv = _InverseTransform(self) | ||
self._inv = weakref.ref(inv) | ||
inv = cast(TransformT, _InverseTransform(self)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems unnecessary to me
inv = _InverseTransform(self) | ||
self._inv = weakref.ref(inv) | ||
inv = cast(TransformT, _InverseTransform(self)) | ||
self._inv = cast(TransformT, weakref.ref(inv)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need this cast, it seems incorrect to me?
x: StrictArrayT, | ||
y: StrictArrayT, | ||
intermediates: Optional[PyTree] = None, | ||
) -> StrictArrayT: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe you can use ArrayLike at those places as long as the Protocol uses the type hint: log_abs_det_jacobian(...) -> NumLike:
@@ -191,7 +198,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: | |||
def tree_flatten(self): | |||
return (self._inv,), (("_inv",), dict()) | |||
|
|||
def __eq__(self, other: TransformT) -> bool: | |||
def __eq__(self, other: object) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just curious, why we allow to compare with non TransformT?
@@ -268,10 +275,13 @@ def __call__(self, x: ArrayLike) -> ArrayLike: | |||
return self.loc + self.scale * x | |||
|
|||
def _inverse(self, y: ArrayLike) -> ArrayLike: | |||
return (y - self.loc) / self.scale | |||
return (y - self.loc) / self.scale # type: ignore[call-overload,operator] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as long as loc
and scale
are NumLike
, this should be fine I think.
@@ -372,7 +385,7 @@ def log_abs_det_jacobian( | |||
) | |||
) | |||
|
|||
result = 0.0 | |||
result = jnp.zeros(()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if output is ArrayLike, I guess we can use 0.0 here.
def call_with_intermediates( | ||
self, x: ArrayLike | ||
) -> Tuple[ArrayLike, Optional[ArrayLike]]: | ||
def call_with_intermediates(self, x: ArrayLike) -> Tuple[ArrayLike, Sequence]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sequence -> PyTree
This PR contains the resolution of mypy errors passed by #2032, in
numpyro.distributions.transforms
module.There are two cases in particular which I am unable to resolve. You can see them by running the mypy.
log_abs_det_jacobian
of many transforms have unused parameters. I have typed them as union ofUnusedParam
and some appropriate numpy/jax type.Many cases were unresolvable, like,
__eq__
method expectsbool
as return type, but&
operation between arrays return array of type bool, which conflicts with the return type, therefore I have added the tag to ignore them. You will find similar tags in the file.