Skip to content

Conversation

Qazalbash
Copy link
Contributor

This PR contains the resolution of mypy errors passed by #2032, in numpyro.distributions.transforms module.

There are two cases in particular which I am unable to resolve. You can see them by running the mypy.

log_abs_det_jacobian of many transforms have unused parameters. I have typed them as union of UnusedParam and some appropriate numpy/jax type.

Many cases were unresolvable, like, __eq__ method expects bool as return type, but & operation between arrays return array of type bool, which conflicts with the return type, therefore I have added the tag to ignore them. You will find similar tags in the file.

@juanitorduz
Copy link
Contributor

This looks great, thanks! There are some minor errors

numpyro/distributions/transforms.py:82: error: Missing positional argument "x" in call to "__call__" of "TransformT"  [call-arg]
Installing missing stub packages:
numpyro/distributions/transforms.py:84: error: Name "inv" already defined on line 80  [no-redef]
numpyro/distributions/transforms.py:85: error: Incompatible types in assignment (expression has type "ReferenceType[None]", variable has type "TransformT | None")  [assignment]
numpyro/distributions/transforms.py:86: error: Incompatible return value type (got "Array | Any | None", expected "TransformT")  [return-value]
numpyro/distributions/transforms.py:1550: error: Item "ndarray[tuple[Any, ...], dtype[Any]]" of "ndarray[tuple[Any, ...], dtype[Any]] | Array" has no attribute "at"  [union-attr]

Do you need some help with these :) ?

@Qazalbash
Copy link
Contributor Author

Qazalbash commented Aug 29, 2025

numpyro/distributions/transforms.py:82: error: Missing positional argument "x" in call to "__call__" of "TransformT"  [call-arg]
Installing missing stub packages:
numpyro/distributions/transforms.py:84: error: Name "inv" already defined on line 80  [no-redef]
numpyro/distributions/transforms.py:85: error: Incompatible types in assignment (expression has type "ReferenceType[None]", variable has type "TransformT | None")  [assignment]
numpyro/distributions/transforms.py:86: error: Incompatible return value type (got "Array | Any | None", expected "TransformT")  [return-value]

These errors are from numpyro/distributions/transforms.py#L77-L86, and I am not able to understand the significance of different conditions and the weak reference. I think you can look into this matter.

I will take this one,

numpyro/distributions/transforms.py:1550: error: Item "ndarray[tuple[Any, ...], dtype[Any]]" of "ndarray[tuple[Any, ...], dtype[Any]] | Array" has no attribute "at"  [union-attr]

Thank you for offering help ❤️.

@juanitorduz
Copy link
Contributor

juanitorduz commented Aug 29, 2025

ok! Sounds like a plan! I will try to look at it in the next days :)

@juanitorduz
Copy link
Contributor

Hey @Qazalbash I gave it a try as in d57d9a6 . MyPy is happy now, maybe you can try it ? The only key change was self.inv() -> self.inv which I think makes more sense, let's see if the tests complain ;)

juanitorduz referenced this pull request Aug 30, 2025
@Qazalbash
Copy link
Contributor Author

@juanitorduz Thanks for the changes, mypy is happy now.

Do I need to remove the plugin from pyproject.toml?

@juanitorduz
Copy link
Contributor

It's depreciated so it's safe to remove

Qazalbash and others added 2 commits August 30, 2025 23:22
…ng of inverse transforms

Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
@Qazalbash Qazalbash requested a review from fehiepsi September 4, 2025 12:04
@juanitorduz
Copy link
Contributor

juanitorduz commented Sep 4, 2025

ok! I think the tests are failing because a new NNX release and changes in nnx.merge, see https://github.com/google/flax/releases/tag/v0.11.2

The other tests FAILED test/test_distributions.py::test_entropy_samples I am not sure about.

@juanitorduz
Copy link
Contributor

Here is a patch for the first errors #2067

@Qazalbash
Copy link
Contributor Author

Here's another patch #2069 😸

x: Union[jax.Array, Any],
y: Union[jax.Array, Any],
intermediates: Optional[Any] = None,
) -> Union[jax.Array, Any]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need those Union, shouldn't Any be enough?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we only use Any, then IDEs show types like this,
image
otherwise,
image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. You want to see some Array there. How about using ArrayLike?

Copy link
Contributor Author

@Qazalbash Qazalbash Sep 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we change it to ArrayLike, then we get these errors.

numpyro/distributions/transforms.py:181: error: Incompatible return value type (got "Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any", expected "ndarray[tuple[Any, ...], dtype[Any]] | Array")  [return-value]
numpyro/distributions/transforms.py:190: error: Unsupported operand type for unary - ("Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any")  [operator]
numpyro/distributions/transforms.py:190: error: Incompatible return value type (got "Array | ndarray[tuple[Any, ...], dtype[Any]] | Any | number[Any, int | float | complex] | int | float | complex", expected "ndarray[tuple[Any, ...], dtype[Any]] | Array")  [return-value]
numpyro/distributions/transforms.py:366: error: Incompatible types in assignment (expression has type "Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any", variable has type "ndarray[tuple[Any, ...], dtype[Any]] | Array")  [assignment]
numpyro/distributions/transforms.py:371: error: Incompatible types in assignment (expression has type "Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any", variable has type "ndarray[tuple[Any, ...], dtype[Any]] | Array")  [assignment]
numpyro/distributions/transforms.py:397: error: Incompatible types in assignment (expression has type "Array | ndarray[tuple[Any, ...], dtype[Any]] | numpy.bool[builtins.bool] | number[Any, int | float | complex] | int | float | complex | Any", variable has type "ndarray[tuple[Any, ...], dtype[Any]] | Array")  [assignment]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, thanks! How about using NumLike like in my comment above?

@@ -20,6 +22,9 @@
Message: TypeAlias = dict[str, Any]
TraceT: TypeAlias = OrderedDict[str, Message]

# ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars
StrictArrayT = Union[np.ndarray, jax.Array]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why StrictArrayT is needed, could you elaborate? I think we can use ArrayLike instead

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some functions do not operate on boolean or complex; ArrayLike contains boolean and complex types, and mypy was throwing an error over it. That's why I created StrictArrayT, it is used where only arrays are required.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we also want to support int, float at those places. Could you let me know which function has the issue?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using ArrayLike in most places and in specific cases, use NumLike? StrictArray can be used at those places which expect non-scalar shapes.

StrictArray = Union[np.ndarray, jax.Array]
NumLike = Union[
    jax.Array,
    np.ndarray, np.number,
    bool, int, float, complex,
]

(we don't need to have T trailing in those cases I think)

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Qazalbash, I think we can get around issues by using NumLike just at some specific places. It's fine to use StrictArray at those arrays with dim >= 1. NonScalarArray is a good name for it I guess.

@@ -20,6 +22,9 @@
Message: TypeAlias = dict[str, Any]
TraceT: TypeAlias = OrderedDict[str, Message]

# ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars
StrictArrayT = Union[np.ndarray, jax.Array]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using ArrayLike in most places and in specific cases, use NumLike? StrictArray can be used at those places which expect non-scalar shapes.

StrictArray = Union[np.ndarray, jax.Array]
NumLike = Union[
    jax.Array,
    np.ndarray, np.number,
    bool, int, float, complex,
]

(we don't need to have T trailing in those cases I think)

def __call__(self, x: ArrayLike) -> ArrayLike: ...
def _inverse(self, y: ArrayLike) -> ArrayLike: ...
def __call__(self, x: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ...
def _inverse(self, y: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer: ArrayLike and not use Any in general

x: Union[jax.Array, Any],
y: Union[jax.Array, Any],
intermediates: Optional[Any] = None,
) -> Union[jax.Array, Any]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, thanks! How about using NumLike like in my comment above?

inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
self._inv = weakref.ref(inv)
inv = cast(TransformT, _InverseTransform(self))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems unnecessary to me

inv = _InverseTransform(self)
self._inv = weakref.ref(inv)
inv = cast(TransformT, _InverseTransform(self))
self._inv = cast(TransformT, weakref.ref(inv))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this cast, it seems incorrect to me?

x: StrictArrayT,
y: StrictArrayT,
intermediates: Optional[PyTree] = None,
) -> StrictArrayT:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you can use ArrayLike at those places as long as the Protocol uses the type hint: log_abs_det_jacobian(...) -> NumLike:

@@ -191,7 +198,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
def tree_flatten(self):
return (self._inv,), (("_inv",), dict())

def __eq__(self, other: TransformT) -> bool:
def __eq__(self, other: object) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious, why we allow to compare with non TransformT?

@@ -268,10 +275,13 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
return self.loc + self.scale * x

def _inverse(self, y: ArrayLike) -> ArrayLike:
return (y - self.loc) / self.scale
return (y - self.loc) / self.scale # type: ignore[call-overload,operator]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as long as loc and scale are NumLike, this should be fine I think.

@@ -372,7 +385,7 @@ def log_abs_det_jacobian(
)
)

result = 0.0
result = jnp.zeros(())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if output is ArrayLike, I guess we can use 0.0 here.

def call_with_intermediates(
self, x: ArrayLike
) -> Tuple[ArrayLike, Optional[ArrayLike]]:
def call_with_intermediates(self, x: ArrayLike) -> Tuple[ArrayLike, Sequence]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sequence -> PyTree

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants