
GPSampler crashes when torch default device is cuda #6113

@argusdusty

Description


Expected behavior

The study runs to completion without errors.

Environment

  • Optuna version: 4.3.0
  • Python version: 3.11.4
  • OS: Windows-10-10.0.26100-SP0
  • Torch version: 2.6.0+cu124

Error messages, stack traces, or logs

...
  File "C:\Users\Argusdusty\AppData\Local\Programs\Python\Python311\Lib\site-packages\optuna\samplers\_gp\sampler.py", line 230, in sample_relative
    kernel_params = gp.fit_kernel_params(
                    ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Argusdusty\AppData\Local\Programs\Python\Python311\Lib\site-packages\optuna\_gp\gp.py", line 264, in fit_kernel_params
    return _fit_kernel_params(
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Argusdusty\AppData\Local\Programs\Python\Python311\Lib\site-packages\optuna\_gp\gp.py", line 183, in _fit_kernel_params
    np.log(initial_kernel_params.inverse_squared_lengthscales.detach().numpy()),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Argusdusty\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\utils\_device.py", line 104, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
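The traceback runs through torch/utils/_device.py, the module behind torch.set_default_device. It supplies the default device to tensor constructors that are called without one, so the tensor reaching .numpy() lives on cuda:0. The error message itself suggests an alternative to pinning tensors to the CPU: copy to host memory at the NumPy boundary. A minimal sketch of that approach, assuming a hypothetical helper named tensor_to_numpy (not part of Optuna):

```python
import numpy as np
import torch


def tensor_to_numpy(t: torch.Tensor) -> np.ndarray:
    # Hypothetical helper, not Optuna's API.
    # .cpu() returns the tensor itself when it already lives on the CPU and
    # copies a CUDA tensor to host memory otherwise, so .numpy() no longer
    # raises "can't convert cuda:0 device type tensor to numpy".
    return t.detach().cpu().numpy()


if __name__ == "__main__":
    # Mirror the reporter's setup when a GPU is available; fall back to CPU otherwise.
    torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
    lengthscales = torch.ones(2, dtype=torch.float64)  # created on the default device
    print(np.log(tensor_to_numpy(lengthscales)))
```

The trade-off: this keeps the GP tensors on whatever device the user selected and pays a device-to-host copy at each conversion, whereas adding device='cpu' at construction time avoids GPU work altogether.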

Steps to reproduce

  1. Call torch.set_default_device("cuda") with a CUDA-enabled PyTorch build.
  2. Create a study with sampler=optuna.samplers.GPSampler().
  3. Call study.optimize(...).

import optuna
import torch

# Every tensor created without an explicit device now lands on the GPU.
torch.set_default_device("cuda")


def objective(trial):
    x = trial.suggest_float("x", -100, 100)
    return x ** 2


study = optuna.create_study(sampler=optuna.samplers.GPSampler())
study.optimize(objective, n_trials=100)  # Raises the TypeError shown in the traceback.

Additional context (optional)

A patch that fixes the crash in my local workspace: add device='cpu' to every call under optuna/_gp/ that creates a new tensor (torch.tensor, torch.ones, torch.empty, torch.eye, etc.).
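The patch described above boils down to passing device='cpu' to each tensor constructor. An explicit device argument takes precedence over torch.set_default_device, which only fills in the device when none is given. A minimal sketch of the pattern, where initial_log_lengthscales is a hypothetical stand-in for the failing line in optuna/_gp/gp.py, not Optuna's actual code:

```python
import numpy as np
import torch


def initial_log_lengthscales(n_params: int) -> np.ndarray:
    # Hypothetical stand-in for the failing np.log(... .detach().numpy()) line.
    # device="cpu" overrides torch.set_default_device(...), so .numpy() works
    # even when the global default device is "cuda".
    inverse_squared_lengthscales = torch.ones(n_params, dtype=torch.float64, device="cpu")
    return np.log(inverse_squared_lengthscales.detach().numpy())


if __name__ == "__main__":
    # Reproduce the reporter's setup when a GPU is present; fall back to CPU otherwise.
    torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
    print(initial_log_lengthscales(3))  # [0. 0. 0.]
```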

Metadata

Assignees: No one assigned

Labels: bug (Issue/PR about behavior that is broken. Not for typos/examples/CI/test but for Optuna itself.)
