本文旨在詳細(xì)闡述如何將vision transformer(vit)模型從單標(biāo)簽多分類任務(wù)轉(zhuǎn)換到多標(biāo)簽分類任務(wù)。核心內(nèi)容聚焦于損失函數(shù)的替換,從`crossentropyloss`轉(zhuǎn)向更適合多標(biāo)簽的`bcewithlogitsloss`,并深入探討多標(biāo)簽分類任務(wù)下模型輸出層、標(biāo)簽格式以及評估指標(biāo)的選擇與實現(xiàn),提供實用的代碼示例和注意事項,以確保模型能夠準(zhǔn)確有效地處理多標(biāo)簽數(shù)據(jù)。
在計算機視覺領(lǐng)域,許多實際應(yīng)用場景需要模型識別圖像中存在的多個獨立特征或類別,而非僅僅識別一個主要類別。例如,一張圖片可能同時包含“貓”、“狗”和“草地”等多個標(biāo)簽。這種任務(wù)被稱為多標(biāo)簽分類(Multi-label Classification),它與傳統(tǒng)的單標(biāo)簽多分類(Single-label Multi-class Classification)有著本質(zhì)的區(qū)別。對于Vision Transformer (ViT) 模型而言,從單標(biāo)簽任務(wù)遷移到多標(biāo)簽任務(wù),主要涉及損失函數(shù)、模型輸出層以及評估策略的調(diào)整。
傳統(tǒng)的單標(biāo)簽多分類任務(wù)通常使用torch.nn.CrossEntropyLoss作為損失函數(shù)。該損失函數(shù)內(nèi)部集成了LogSoftmax和NLLLoss,它期望模型的輸出是每個類別的原始分?jǐn)?shù)(logits),而標(biāo)簽是一個整數(shù),代表唯一的正確類別。然而,在多標(biāo)簽分類中,一個樣本可能同時屬于多個類別,因此CrossEntropyLoss不再適用。
替換為 BCEWithLogitsLoss
對于多標(biāo)簽分類任務(wù),標(biāo)準(zhǔn)的做法是使用二元交叉熵?fù)p失函數(shù)。torch.nn.BCEWithLogitsLoss是一個非常合適的選擇,它結(jié)合了Sigmoid激活函數(shù)和二元交叉熵?fù)p失(Binary Cross Entropy Loss)。
BCEWithLogitsLoss的優(yōu)勢在于:
模型輸出與標(biāo)簽格式
在多標(biāo)簽分類中,模型的輸出層需要進(jìn)行調(diào)整。如果原始模型用于單標(biāo)簽分類,其最后一層可能輸出一個與類別數(shù)量相等的logit向量,并通過Softmax激活函數(shù)進(jìn)行概率歸一化。對于多標(biāo)簽分類,模型最后一層也應(yīng)輸出一個與類別數(shù)量相等的logit向量,但不應(yīng)在其后接Softmax激活函數(shù)。這些原始的logits將直接輸入到BCEWithLogitsLoss中。
標(biāo)簽的格式也必須是多熱編碼(multi-hot encoding),即一個與類別數(shù)量相等的向量,其中1表示該類別存在,0表示不存在。此外,標(biāo)簽的數(shù)據(jù)類型必須是浮點型(torch.float),以匹配BCEWithLogitsLoss的輸入要求。
代碼示例:損失函數(shù)替換
假設(shè)我們有7個可能的類別,并且標(biāo)簽格式如 [0, 1, 1, 0, 0, 1, 0]。
import torch import torch.nn as nn # 假設(shè)模型輸出的原始logits (batch_size, num_classes) # 這里以一個batch_size為1的示例 num_classes = 7 model_output_logits = torch.randn(1, num_classes) # 模擬模型輸出的原始logits # 真實標(biāo)簽,必須是float類型且為多熱編碼 # 示例標(biāo)簽: [0, 1, 1, 0, 0, 1, 0] 表示第1, 2, 5個類別存在 true_labels = torch.tensor([[0, 1, 1, 0, 0, 1, 0]]).float() # 定義BCEWithLogitsLoss loss_function = nn.BCEWithLogitsLoss() # 計算損失 loss = loss_function(model_output_logits, true_labels) print(f"模型輸出 logits: {model_output_logits}") print(f"真實標(biāo)簽: {true_labels}") print(f"計算得到的損失: {loss.item()}") # 在訓(xùn)練循環(huán)中的應(yīng)用示例 # pred = model(images.to(device)) # 模型輸出原始logits # labels = labels.to(device).float() # 確保標(biāo)簽是float類型 # loss = loss_function(pred, labels) # loss.backward() # optimizer.step()
注意事項:
單標(biāo)簽分類任務(wù)通常使用準(zhǔn)確率(Accuracy)作為主要評估指標(biāo)。然而,在多標(biāo)簽分類中,由于一個樣本可能有多個正確標(biāo)簽,或者沒有標(biāo)簽,簡單的準(zhǔn)確率不再能全面反映模型性能。我們需要采用更細(xì)致的評估指標(biāo)。
獲取預(yù)測結(jié)果
BCEWithLogitsLoss處理的是原始logits,為了進(jìn)行評估,我們需要將這些logits轉(zhuǎn)換為二元預(yù)測(0或1)。這通常通過Sigmoid激活函數(shù)和設(shè)定一個閾值(threshold)來完成。
# 假設(shè) model_output_logits 是模型的原始輸出 # model_output_logits = torch.randn(1, num_classes) # 從上面示例延續(xù) # 將logits通過Sigmoid函數(shù)轉(zhuǎn)換為概率 probabilities = torch.sigmoid(model_output_logits) # 設(shè)定閾值,通常為0.5 threshold = 0.5 # 將概率轉(zhuǎn)換為二元預(yù)測 predictions = (probabilities > threshold).int() print(f"預(yù)測概率: {probabilities}") print(f"二元預(yù)測 (閾值={threshold}): {predictions}")
常用的多標(biāo)簽評估指標(biāo)
以下是多標(biāo)簽分類中常用的評估指標(biāo):
精確率(Precision)、召回率(Recall)和F1分?jǐn)?shù)(F1-score): 這些指標(biāo)可以針對每個類別獨立計算,也可以通過平均策略(Micro-average, Macro-average)進(jìn)行匯總。
漢明損失(Hamming Loss): 衡量預(yù)測錯誤的標(biāo)簽占總標(biāo)簽的比例。值越低越好。 Hamming Loss = (錯誤預(yù)測的標(biāo)簽數(shù)量) / (總標(biāo)簽數(shù)量)
Jaccard 指數(shù)(Jaccard Index / IoU): 衡量預(yù)測標(biāo)簽集合與真實標(biāo)簽集合的相似度。對于每個樣本,Jaccard指數(shù) = |預(yù)測標(biāo)簽 ∩ 真實標(biāo)簽| / |預(yù)測標(biāo)簽 ∪ 真實標(biāo)簽|。然后可以對所有樣本取平均。
平均準(zhǔn)確率(Average Precision, AP)和平均精度均值(Mean Average Precision, mAP): 在某些場景(如目標(biāo)檢測)中非常流行,但也可用于多標(biāo)簽分類。AP是PR曲線下的面積,mAP是所有類別AP的平均值。
使用 scikit-learn 進(jìn)行評估
scikit-learn庫提供了豐富的函數(shù)來計算這些指標(biāo)。
from sklearn.metrics import precision_score, recall_score, f1_score, hamming_loss, jaccard_score import numpy as np # 假設(shè)有多個樣本的預(yù)測和真實標(biāo)簽 # true_labels_np 和 predictions_np 都是 (num_samples, num_classes) 的二維數(shù)組 true_labels_np = np.array([ [0, 1, 1, 0, 0, 1, 0], [1, 0, 0, 1, 0, 0, 0], [0, 0, 1, 1, 1, 0, 0] ]) predictions_np = np.array([ [0, 1, 0, 0, 0, 1, 0], # 樣本0: 預(yù)測對2個,錯1個(少預(yù)測一個標(biāo)簽) [1, 1, 0, 0, 0, 0, 0], # 樣本1: 預(yù)測對1個,錯1個(多預(yù)測一個標(biāo)簽) [0, 0, 1, 1, 0, 0, 0] # 樣本2: 預(yù)測對2個,錯1個(少預(yù)測一個標(biāo)簽) ]) # 轉(zhuǎn)換為一維數(shù)組以便于部分scikit-learn函數(shù)處理(對于micro/macro平均) # 或者直接使用多維數(shù)組并指定average='samples'/'weighted'/'none' y_true_flat = true_labels_np.flatten() y_pred_flat = predictions_np.flatten() print(f"真實標(biāo)簽:\n{true_labels_np}") print(f"預(yù)測標(biāo)簽:\n{predictions_np}") # Micro-average F1-score micro_f1 = f1_score(true_labels_np, predictions_np, average='micro') print(f"Micro-average F1-score: {micro_f1:.4f}") # Macro-average F1-score macro_f1 = f1_score(true_labels_np, predictions_np, average='macro') print(f"Macro-average F1-score: {macro_f1:.4f}") # Per-class F1-score per_class_f1 = f1_score(true_labels_np, predictions_np, average=None) print(f"Per-class F1-score: {per_class_f1}") # Hamming Loss h_loss = hamming_loss(true_labels_np, predictions_np) print(f"Hamming Loss: {h_loss:.4f}") # Jaccard Score (Average over samples) # 注意:jaccard_score在多標(biāo)簽中默認(rèn)是average='binary',需要指定其他平均方式 jaccard = jaccard_score(true_labels_np, predictions_np, average='samples') print(f"Jaccard Score (Average over samples): {jaccard:.4f}")
評估流程建議: 在訓(xùn)練過程中,可以定期計算Micro-F1或Macro-F1作為監(jiān)控指標(biāo)。在模型訓(xùn)練完成后,進(jìn)行全面的評估,包括各項指標(biāo)的計算,并分析每個類別的性能。
將ViT模型從單標(biāo)簽多分類轉(zhuǎn)換為多標(biāo)簽分類,關(guān)鍵在于理解任務(wù)性質(zhì)的變化并進(jìn)行相應(yīng)的調(diào)整。核心步驟包括:
通過這些調(diào)整,ViT模型能夠有效地處理多標(biāo)簽分類任務(wù),從而在更復(fù)雜的實際應(yīng)用中發(fā)揮其強大的特征學(xué)習(xí)能力。
以上就是ViT多標(biāo)簽分類:損失函數(shù)與評估策略改造指南的詳細(xì)內(nèi)容,更多請關(guān)注php中文網(wǎng)其它相關(guān)文章!
每個人都需要一臺速度更快、更穩(wěn)定的 PC。隨著時間的推移,垃圾文件、舊注冊表數(shù)據(jù)和不必要的后臺進(jìn)程會占用資源并降低性能。幸運的是,許多工具可以讓 Windows 保持平穩(wěn)運行。
微信掃碼
關(guān)注PHP中文網(wǎng)服務(wù)號
QQ掃碼
加入技術(shù)交流群
Copyright 2014-2025 http://ipnx.cn/ All Rights Reserved | php.cn | 湘ICP備2023035733號