import math
from typing import List, Optional, TypeAlias, Union

import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor

ImageInput: TypeAlias = Union[Image.Image, List[Image.Image]]
BatchImageInput: TypeAlias = Union[List[Image.Image], List[List[Image.Image]]]

class OpsMMEmbeddingV1(nn.Module):
    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        max_length: Optional[int] = None,
        attn_implementation: Optional[str] = None,
        device_map: Optional[str] = None,
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        torch_dtype: Optional[torch.dtype] = torch.bfloat16,
    ):
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.default_instruction = "You are a helpful assistant."

        load_kwargs = dict(
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
        )
        # Quantization options (requires bitsandbytes for 4/8-bit)
        if load_in_4bit:
            load_kwargs["load_in_4bit"] = True
        if load_in_8bit:
            load_kwargs["load_in_8bit"] = True
        if device_map is not None:
            load_kwargs["device_map"] = device_map

        self.base_model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            **load_kwargs,
        )
        # Only move to a single device when not using tensor-parallel sharding
        if device_map is None:
            self.base_model = self.base_model.to(self.device)

        self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
        self.processor.tokenizer.padding_side = "left"
        self.eval()

    def encode_input(self, input):
        # Forward pass through the VLM; pool the embedding from the last hidden layer.
        hidden_states = self.base_model(**input, return_dict=True, output_hidden_states=True)
        hidden_states = hidden_states.hidden_states[-1]
        pooled_output = self._pooling(hidden_states)
        return pooled_output

    def _pooling(self, last_hidden_state):
        # Last-token pooling: with left padding, position -1 is always a real token.
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size), -1, :]
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    def _validate_instructions(
        self,
        texts: Optional[List[str]],
        images: Optional[BatchImageInput],
        instruction: Optional[Union[str, List[str]]],
    ) -> List[str]:
        """Validate and format instructions to match the batch size."""
        batch_size = max(len(x) if x is not None else 0 for x in [texts, images])

        if instruction is None:
            return [self.default_instruction] * batch_size

        if isinstance(instruction, str):
            return [instruction] * batch_size

        if isinstance(instruction, list):
            if len(instruction) != batch_size:
                raise ValueError(
                    f"Length of instruction list ({len(instruction)}) must match batch size ({batch_size}) when texts/images are provided"
                )
            return instruction

        raise TypeError("instruction must be str, List[str] or None")

    def _process_images(self, images: ImageInput) -> List[Image.Image]:
        """Convert a single image or a list of images to the processed format."""
        if isinstance(images, (Image.Image, str)):
            return [fetch_image(images)]
        return [fetch_image(i) for i in images]

    def embed(
        self,
        texts: Optional[List[str]] = None,
        images: Optional[BatchImageInput] = None,
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Generate embeddings for text, images, or combined inputs.

        Args:
            texts: List of text inputs (optional)
            images: Can be:
                - List[Image.Image]: Single image per input
                - List[List[Image.Image]]: Multiple images per input
            instruction: Instruction(s) for the model. Can be:
                - None: use default instruction
                - str: use same instruction for all inputs
                - List[str]: per-input instructions (must match batch size)
        """
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        instructions = self._validate_instructions(texts, images, instruction)

        # Determine batch size
        batch_size = len(texts) if texts is not None else len(images)  # type: ignore

        input_texts, input_images = [], []
        for i in range(batch_size):
            text = texts[i] if texts is not None else None
            image = images[i] if images is not None else None

            input_str = ""
            processed_image = None
            if image is not None:
                processed_image = self._process_images(image)
                input_str += "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_image)

            if text is not None:
                input_str += text

            msg = f"<|im_start|>system\n{instructions[i]}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"

            input_texts.append(msg)
            input_images.append(processed_image)

        # Only pass to processor if we actually have images
        processed_images = input_images if any(img is not None for img in input_images) else None

        inputs = self.processor(
            text=input_texts,
            images=processed_images,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.inference_mode():
            embeddings = self.encode_input(inputs)

        return embeddings

    def get_text_embeddings(
        self,
        texts: List[str],
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convenience method for text-only embeddings"""
        return self.get_fused_embeddings(texts=texts, instruction=instruction, **kwargs)

    def get_image_embeddings(
        self,
        images: BatchImageInput,
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convenience method for image-only embeddings.

        Args:
            images: Can be:
                - List[Image.Image]: Single image per input
                - List[List[Image.Image]]: Multiple images per input
        """
        return self.get_fused_embeddings(images=images, instruction=instruction, **kwargs)

    def get_fused_embeddings(
        self,
        texts: Optional[List[str]] = None,
        images: Optional[BatchImageInput] = None,
        instruction: Optional[Union[str, List[str]]] = None,
        batch_size: int = 8,
        show_progress: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Batch processing for large collections of texts/images.

        Args:
            texts: List of text inputs (optional)
            images: Can be:
                - List[Image.Image]: Single image per input
                - List[List[Image.Image]]: Multiple images per input
            instruction: Instruction(s) for the model
            batch_size: Number of items to process at once
            show_progress: Whether to display progress bar
        """
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        total_items = len(texts) if texts is not None else len(images)  # type: ignore
        num_batches = math.ceil(total_items / batch_size)

        all_embeddings = []
        progress = tqdm(total=num_batches, disable=not show_progress, desc="Processing")

        for i in range(0, total_items, batch_size):
            batch_texts = texts[i : i + batch_size] if texts is not None else None
            batch_images = images[i : i + batch_size] if images is not None else None
            # Slice per-input instruction lists so their length matches the current batch
            batch_instruction = instruction[i : i + batch_size] if isinstance(instruction, list) else instruction
            batch_emb = self.embed(texts=batch_texts, images=batch_images, instruction=batch_instruction)

            all_embeddings.append(batch_emb.cpu())
            progress.update(1)

        progress.close()
        return torch.cat(all_embeddings, dim=0).to(self.device)

    def forward(self, **inputs) -> torch.Tensor:
        """Alias for encode_input"""
        return self.encode_input(inputs)

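# Example usage (illustrative sketch): the checkpoint path below is a placeholder
# for whatever Ops-MM-embedding / Qwen2-VL-style checkpoint is actually loaded.
#
#     model = OpsMMEmbeddingV1("path/to/ops-mm-embedding-checkpoint", device="cuda")
#     txt_emb = model.get_text_embeddings(["a photo of a cat", "a photo of a dog"])
#     img_emb = model.get_image_embeddings([Image.open("cat.jpg"), Image.open("dog.jpg")])
#     scores = txt_emb @ img_emb.T  # embeddings are L2-normalized, so this is cosine similarity
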
### Modified from qwen_vl_utils.vision_process.py
import base64
import logging
import math
from io import BytesIO

import requests

IMAGE_FACTOR = 28
MIN_PIXELS = 256 * 28 * 28
MAX_PIXELS = 1280 * 28 * 28
MAX_RATIO = 200

def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int | float, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int | float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor

def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
        logging.warning(f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}; clamping to the limit")
        if h_bar > w_bar:
            h_bar = w_bar * MAX_RATIO
        else:
            w_bar = h_bar * MAX_RATIO
    return h_bar, w_bar

def fetch_image(
    image: str | Image.Image,
    size_factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Image.Image:
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(requests.get(image, stream=True).raw)  # type: ignore
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input; supported inputs are a local path, http url, base64 data URI, or PIL.Image, got {image}")
    image = image_obj.convert("RGB")
    width, height = image.size
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )
    image = image.resize((resized_width, resized_height))

    return image


###
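

# Minimal self-check (illustrative sketch): exercises the smart_resize invariants
# described in its docstring without requiring model weights or network access.
if __name__ == "__main__":
    h, w = smart_resize(1080, 1920)
    assert h % IMAGE_FACTOR == 0 and w % IMAGE_FACTOR == 0
    assert MIN_PIXELS <= h * w <= MAX_PIXELS
    print(f"smart_resize(1080, 1920) -> ({h}, {w})")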