1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
class ZeroShotClassifier:
    """Zero-shot image classifier built on a joint image/text embedding model.

    ``model`` must expose ``encode_image`` and ``encode_text``; ``tokenizer``
    is called with the list of text prompts plus ``padding=True`` and
    ``return_tensors='pt'``.
    """

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used by :meth:`classify`."""
        self.model = model
        self.tokenizer = tokenizer

    def classify(self, image, class_names):
        """Pick the class whose text prompt best matches ``image``.

        Args:
            image: Input handed unchanged to ``model.encode_image``. Only the
                first row of the resulting similarity matrix is scored.
            class_names: Iterable of class names, e.g.
                ``["cat", "dog", "bird"]``. Each name is turned into the
                prompt ``"a photo of a <name>"``.

        Returns:
            A ``(index, confidence)`` tuple: the position of the best-matching
            class within ``class_names`` and its softmax probability, both as
            plain Python numbers.

        Raises:
            ValueError: If ``class_names`` is empty.
        """
        # Build the prompts first so an empty input fails fast with a clear
        # message, before any encoding work is done.
        text_descriptions = [f"a photo of a {name}" for name in class_names]
        if not text_descriptions:
            raise ValueError("class_names must contain at least one class name")

        text_tokens = self.tokenizer(
            text_descriptions, padding=True, return_tensors='pt'
        )

        # Inference only: one no-grad context covers encoding and scoring.
        with torch.no_grad():
            image_features = self.model.encode_image(image)
            text_features = self.model.encode_text(text_tokens)

            # Unit-normalise so the matrix product below is cosine similarity.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Scaling by 100 before the softmax sharpens the distribution.
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        values, indices = similarity[0].topk(1)
        return indices[0].item(), values[0].item()
def zero_shot_classification(model, image, class_names, tokenizer):
    """Classify ``image`` against ``class_names`` and return the winning name.

    Convenience wrapper: builds a ``ZeroShotClassifier`` from ``model`` and
    ``tokenizer``, runs ``classify``, and maps the returned index back to the
    corresponding entry of ``class_names``.

    Returns:
        A ``(class_name, confidence)`` tuple.
    """
    best_index, score = ZeroShotClassifier(model, tokenizer).classify(
        image, class_names
    )
    return class_names[best_index], score
# Usage example: classify one image against a fixed set of candidate labels.
# `model`, `tokenizer` and `preprocess_image` are expected to be defined
# elsewhere before this point.
class_names = [
    "cat",
    "dog",
    "bird",
    "fish",
    "horse",
]
image = preprocess_image("test.jpg")

predicted_class, confidence = zero_shot_classification(
    model=model,
    image=image,
    class_names=class_names,
    tokenizer=tokenizer,
)
print(f"Predicted: {predicted_class}, Confidence: {confidence:.4f}")
|