I downloaded the `03. distinguishing_particles_in_brightfield_tutorial.ipynb` notebook from the tutorials and ran it locally. However, executing the following code raised an error:
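```python
import deeptrack.deeplay as dl  # the traceback shows dl is the deeptrack.deeplay module
import torchmetrics as tm

# net and data_pipeline are defined in earlier cells of the notebook
model = dl.Model(
    net,
    train_data=data_pipeline,
    val_data=data_pipeline,
    loss=dl.torch.nn.CrossEntropyLoss(),
    optimizer=dl.Adam(lr=1e-3),
    metrics=[tm.F1Score(task="multiclass", num_classes=3)],
)
```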
```text
AttributeError: module 'deeptrack.deeplay' has no attribute 'Model'
```
To resolve the `AttributeError`, I changed the code to use `dl.Regressor` instead of `dl.Model`:
```python
model = dl.Regressor(
    net,
    loss=dl.torch.nn.CrossEntropyLoss(),
    optimizer=dl.Adam(lr=1e-3),
    metrics=[tm.F1Score(task="multiclass", num_classes=3)],
)
```
Creating the model with `dl.Regressor` works, but the following code then raises another error:
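```python
import numpy as np

# Draw a batch of 4 images and compute class probabilities (softmax over dimension 1)
input_image, target_image = data_pipeline.batch(4)
predicted_image = model.predict(input_image.astype(np.float32)).softmax(1)
```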
```text
RuntimeError: Given groups=1, weight of size [32, 1, 3, 3], expected input[4, 128, 128, 1] to have 1 channels, but got 128 channels instead
```
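The shape in the traceback, `[4, 128, 128, 1]`, looks like a channels-last batch `(N, H, W, C)`, whereas `torch.nn.Conv2d` expects channels-first `(N, C, H, W)`; the first convolution (weight `[32, 1, 3, 3]`, i.e. 1 input channel) therefore reads the 128-pixel height as 128 channels. A minimal sketch, assuming `data_pipeline.batch(4)` returns NumPy arrays in that layout, is to move the channel axis before calling the model. `to_channels_first` is a hypothetical helper, and random data stands in for the real pipeline output.

```python
import numpy as np


def to_channels_first(batch: np.ndarray) -> np.ndarray:
    """Convert an (N, H, W, C) image batch to the (N, C, H, W) layout torch.nn.Conv2d expects."""
    if batch.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {batch.shape}")
    # Move the trailing channel axis to position 1 and cast to float32 for the network.
    return np.ascontiguousarray(np.moveaxis(batch, -1, 1), dtype=np.float32)


# Hypothetical stand-in for data_pipeline.batch(4): four 128x128 single-channel images.
input_image = np.random.rand(4, 128, 128, 1)
x = to_channels_first(input_image)
print(x.shape)  # (4, 1, 128, 128)
```

In the notebook this would become something like `model.predict(to_channels_first(input_image)).softmax(1)`. If the same pipeline also feeds training, its images (and possibly the targets given to `CrossEntropyLoss`) may need the same axis change.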
I am using PyTorch 2.2.1 and Python 3.11.5.
- Is my modification (using `dl.Regressor` instead of `dl.Model`) correct? If so, how can I fix the channel-mismatch `RuntimeError` raised by the prediction code above?
- The deeplay tutorial says that the `.fit()` method handles training, validation, and logging, and automatically selects the best available device (the GPU, if there is one). However, training always seems to run on the CPU. How can I force the training loop to run on the GPU?
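Before changing the training code, it may be worth confirming that the installed PyTorch build can see a GPU at all: a CPU-only build reports `torch.version.cuda` as `None` and `torch.cuda.is_available()` as `False`, and then no trainer setting can move training to the GPU. The sketch below uses only plain PyTorch calls; `pick_device` is a hypothetical helper. If deeplay's trainer is built on PyTorch Lightning (an assumption worth checking in the deeplay docs), Lightning's `Trainer` accepts `accelerator="gpu"` and `devices=1` to request a GPU explicitly.

```python
import torch


def pick_device() -> torch.device:
    """Return the CUDA device if this PyTorch build can see a GPU, else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = pick_device()
# torch.version.cuda is None for CPU-only PyTorch builds, which can never train on a GPU.
print("CUDA build:", torch.version.cuda, "| selected device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
```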
Thanks in advance for your help!