一、train中遇到的问题

(一)python、pytorch、cuda版本不对应

swin-unet官方仓库上写的使用的是python3.7运行的代码,所以我一开始把环境全部朝python3.7去配置。却一直报错。

经过一番搜索后,发现python3.7对应的环境无法在4060laptop上运行。

在多次尝试不同的环境,并结合b站复现别的论文的视频,选择将python版本改为3.8。

1、新建独立环境

1
2
conda create -n py.8 python=3.8  # 明确指定Python 3.8
conda activate py.8

2、使用pip绕过conda依赖限制

1
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118

(二)一堆cuda的报错

image-20250228152925629

根据github中issue的讨论,获得修改方法,train.py中的num_classes和n_class都要设置为9

https://github.com/HuCaoFighting/Swin-Unet/issues/121

(三)安装完requirements.txt中的库后仍然缺少部分库

根据搜索安装即可

(四)训练集、验证集地址、名称问题

对trainer.py中的相关代码进行如下修改

image-20250228153722709

(五)windows中不能使用多线程

二、test中遇到的问题

(一)找不到best_model.pth.txt文件

image-20250228220644939

(二)文件地址错乱

(一)(二)的解决方法相同:

代码中的volum_path统一改为root_path,然后根据报错提示修改对应的地址。

(三)维度出现问题

修改utils.py的代码

原代码:

1
2
image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)

修改后:

1
2
3
4
5
6
7
8
9
10
image = image.cpu().detach().numpy()
label = label.cpu().detach().numpy()

if image.shape[0] == 1:
image = image.squeeze(0)
if label.shape[0] == 1:
label = label.squeeze(0)

#image, label = image.squeeze(0).cpu().detach().numpy().squeeze(0), label.squeeze(0).cpu().detach().numpy().squeeze(0)

参考:

https://juejin.cn/post/7431728417744175154

三、test结果

第六类不知为啥数值都是0…

image-20250228234137594

四、总结

这是我第一次尝试复现代码,用时一天半终于把环境配好,第一次成功运行代码。

用时6:16:35训练完成!!!

image-20250228211915969