diff --git a/__pycache__/baidu_vdb_backend.cpython-310.pyc b/__pycache__/baidu_vdb_backend.cpython-310.pyc index 2c09ce0..4b40d89 100644 Binary files a/__pycache__/baidu_vdb_backend.cpython-310.pyc and b/__pycache__/baidu_vdb_backend.cpython-310.pyc differ diff --git a/__pycache__/faiss_vector_store.cpython-310.pyc b/__pycache__/faiss_vector_store.cpython-310.pyc new file mode 100644 index 0000000..1f97413 Binary files /dev/null and b/__pycache__/faiss_vector_store.cpython-310.pyc differ diff --git a/__pycache__/multimodal_retrieval_faiss.cpython-310.pyc b/__pycache__/multimodal_retrieval_faiss.cpython-310.pyc new file mode 100644 index 0000000..593e644 Binary files /dev/null and b/__pycache__/multimodal_retrieval_faiss.cpython-310.pyc differ diff --git a/__pycache__/multimodal_retrieval_local.cpython-310.pyc b/__pycache__/multimodal_retrieval_local.cpython-310.pyc new file mode 100644 index 0000000..8093d61 Binary files /dev/null and b/__pycache__/multimodal_retrieval_local.cpython-310.pyc differ diff --git a/__pycache__/multimodal_retrieval_vdb.cpython-310.pyc b/__pycache__/multimodal_retrieval_vdb.cpython-310.pyc index 4e8ceb8..e92d963 100644 Binary files a/__pycache__/multimodal_retrieval_vdb.cpython-310.pyc and b/__pycache__/multimodal_retrieval_vdb.cpython-310.pyc differ diff --git a/__pycache__/optimized_file_handler.cpython-310.pyc b/__pycache__/optimized_file_handler.cpython-310.pyc index 63aa8f1..06082b9 100644 Binary files a/__pycache__/optimized_file_handler.cpython-310.pyc and b/__pycache__/optimized_file_handler.cpython-310.pyc differ diff --git a/__pycache__/proxy_utils.cpython-310.pyc b/__pycache__/proxy_utils.cpython-310.pyc new file mode 100644 index 0000000..38cffb9 Binary files /dev/null and b/__pycache__/proxy_utils.cpython-310.pyc differ diff --git a/app_log.txt b/app_log.txt new file mode 100644 index 0000000..2f676af --- /dev/null +++ b/app_log.txt @@ -0,0 +1,78 @@ +nohup: ignoring input +INFO:baidu_bos_manager:✅ BOS连接测试成功 +INFO:baidu_bos_manager:✅ BOS客户端初始化成功: dmtyz-demo +INFO:mongodb_manager:✅ MongoDB连接成功: mmeb +INFO:mongodb_manager:✅ MongoDB索引创建完成 +INFO:__main__:初始化多模态检索系统... +INFO:multimodal_retrieval_local:使用GPU: [0, 1] +INFO:multimodal_retrieval_local:加载本地模型和处理器: /root/models/Ops-MM-embedding-v1-7B +The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release. +You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0. +INFO:multimodal_retrieval_local:Processor类型: +INFO:multimodal_retrieval_local:Processor方法: ['__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_check_special_mm_tokens', '_create_repo', '_get_arguments_from_pretrained', '_get_files_timestamps', '_get_num_multimodal_tokens', '_merge_kwargs', '_upload_modified_files', 'apply_chat_template', 'attributes', 'audio_tokenizer', 'batch_decode', 'chat_template', 'check_argument_for_proper_class', 'decode', 'feature_extractor_class', 'from_args_and_dict', 'from_pretrained', 'get_possibly_dynamic_module', 'get_processor_dict', 'image_processor', 'image_processor_class', 'image_token', 'image_token_id', 'model_input_names', 'optional_attributes', 'optional_call_args', 'post_process_image_text_to_text', 'push_to_hub', 'register_for_auto_class', 'save_pretrained', 'to_dict', 'to_json_file', 'to_json_string', 'tokenizer', 'tokenizer_class', 'validate_init_kwargs', 'video_processor', 'video_processor_class', 'video_token', 'video_token_id'] +INFO:multimodal_retrieval_local:Image processor类型: +INFO:multimodal_retrieval_local:Image processor方法: ['__backends', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slotnames__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_create_repo', '_further_process_kwargs', '_fuse_mean_std_and_rescale_factor', '_get_files_timestamps', '_prepare_image_like_inputs', '_prepare_images_structure', '_preprocess', '_preprocess_image_like_inputs', '_process_image', '_processor_class', '_set_processor_class', '_upload_modified_files', '_valid_kwargs_names', '_validate_preprocess_kwargs', 'center_crop', 'compile_friendly_resize', 'convert_to_rgb', 'crop_size', 'data_format', 'default_to_square', 'device', 'disable_grouping', 'do_center_crop', 'do_convert_rgb', 'do_normalize', 'do_rescale', 'do_resize', 'fetch_images', 'filter_out_unused_kwargs', 'from_dict', 'from_json_file', 'from_pretrained', 'get_image_processor_dict', 'get_number_of_image_patches', 'image_mean', 'image_processor_type', 'image_std', 'input_data_format', 'max_pixels', 'merge_size', 'min_pixels', 'model_input_names', 'normalize', 'patch_size', 'preprocess', 'push_to_hub', 'register_for_auto_class', 'resample', 'rescale', 'rescale_and_normalize', 'rescale_factor', 'resize', 'return_tensors', 'save_pretrained', 'size', 'temporal_patch_size', 'to_dict', 'to_json_file', 'to_json_string', 'unused_kwargs', 'valid_kwargs'] + Loading checkpoint shards: 0%| | 0/4 [00:00 +INFO:multimodal_retrieval_local:encode_image: 图像列表,长度: 1 +INFO:multimodal_retrieval_local:encode_image: 处理图像输入 +INFO:multimodal_retrieval_local:encode_image: 图像 0 格式: JPEG, 模式: RGB, 大小: (939, 940) +INFO:multimodal_retrieval_local:encode_image: 使用image_processor处理图像 +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 05:52:40] "GET / HTTP/1.1" 200 - +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 05:52:41] "GET /api/system_info HTTP/1.1" 200 - +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 05:52:41] "GET /api/system_info HTTP/1.1" 200 - +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 05:52:42] "GET /favicon.ico HTTP/1.1" 404 - +INFO:multimodal_retrieval_local:encode_image: 处理后的输入键: ['pixel_values'] +INFO:__main__:处理图像: 微信图片_20250910164839_1_13.jpg (99396 字节) +INFO:__main__:成功加载图像: 20250910164839_1_13.jpg, 格式: JPEG, 模式: RGB, 大小: (939, 940) +INFO:multimodal_retrieval_local:add_images: 开始添加图像,数量: 1 +INFO:multimodal_retrieval_local:add_images: 编码图像 +INFO:multimodal_retrieval_local:encode_image: 开始编码图像,类型: +INFO:multimodal_retrieval_local:encode_image: 图像列表,长度: 1 +INFO:multimodal_retrieval_local:encode_image: 处理图像输入 +INFO:multimodal_retrieval_local:encode_image: 图像 0 格式: JPEG, 模式: RGB, 大小: (939, 940) +INFO:multimodal_retrieval_local:encode_image: 使用image_processor处理图像 +INFO:multimodal_retrieval_local:encode_image: 运行模型推理 +INFO:multimodal_retrieval_local:Model类型: +INFO:multimodal_retrieval_local:Model属性: ['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_auto_class', '_backward_compatibility_gradient_checkpointing', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_can_compile_fullgraph', '_can_record_outputs', '_can_set_attn_implementation', '_check_and_adjust_attn_implementation', '_checkpoint_conversion_mapping', '_compiled_call_impl', '_convert_head_mask_to_5d', '_copy_lm_head_original_to_resized', '_create_repo', '_dispatch_accelerate_model', '_fix_state_dict_key_on_load', '_fix_state_dict_key_on_save', '_fix_state_dict_keys_on_save', '_flash_attn_2_can_dispatch', '_flash_attn_3_can_dispatch', '_flex_attn_can_dispatch', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_from_config', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_files_timestamps', '_get_key_renaming_mapping', '_get_name', '_get_no_split_modules', '_get_resized_embeddings', '_get_resized_lm_head', '_hf_hook', '_hf_peft_config_loaded', '_hook_rss_memory_post_forward', '_hook_rss_memory_pre_forward', '_init_added_embeddings_weights_with_mean', '_init_added_lm_head_bias_with_mean', '_init_added_lm_head_weights_with_mean', '_init_weights', '_initialize_missing_keys', '_initialize_weights', '_input_embed_layer', '_is_full_backward_hook', '_is_hf_initialized', '_is_stateful', '_keep_in_fp32_modules', '_keep_in_fp32_modules', '_keep_in_fp32_modules_strict', '_keep_in_fp32_modules_strict', '_keys_to_ignore_on_load_missing', '_keys_to_ignore_on_load_unexpected', '_keys_to_ignore_on_save', '_load_from_flax', '_load_from_state_dict', '_load_from_tf', '_load_pretrained_model', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_move_missing_keys_from_meta_to_cpu', '_named_members', '_no_split_modules', '_no_split_modules', '_non_persistent_buffers_set', '_old_forward', '_parameters', '_pp_plan', '_pp_plan', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_parallel', '_resize_token_embeddings', '_save_to_state_dict', '_sdpa_can_dispatch', '_set_default_torch_dtype', '_set_gradient_checkpointing', '_skip_keys_device_placement', '_slow_forward', '_state_dict_hooks', '_state_dict_pre_hooks', '_supports_attention_backend', '_supports_flash_attn', '_supports_flex_attn', '_supports_sdpa', '_tie_encoder_decoder_weights', '_tie_or_clone_weights', '_tied_weights_keys', '_tp_plan', '_tp_size', '_upload_modified_files', '_version', '_wrapped_call_impl', 'active_adapter', 'active_adapters', 'add_adapter', 'add_memory_hooks', 'add_model_tags', 'add_module', 'apply', 'base_model', 'base_model_prefix', 'bfloat16', 'buffers', 'call_super_init', 'can_generate', 'can_record_outputs', 'children', 'compile', 'config', 'config_class', 'cpu', 'create_extended_attention_mask_for_decoder', 'cuda', 'cuda', 'delete_adapter', 'dequantize', 'device', 'disable_adapters', 'disable_input_require_grads', 'double', 'dtype', 'dummy_inputs', 'dump_patches', 'enable_adapters', 'enable_input_require_grads', 'estimate_tokens', 'eval', 'extra_repr', 'float', 'floating_point_ops', 'forward', 'forward', 'framework', 'from_pretrained', 'generation_config', 'get_adapter_state_dict', 'get_buffer', 'get_compiled_call', 'get_correct_attn_implementation', 'get_decoder', 'get_extended_attention_mask', 'get_extra_state', 'get_head_mask', 'get_image_features', 'get_init_context', 'get_input_embeddings', 'get_memory_footprint', 'get_output_embeddings', 'get_parameter', 'get_parameter_or_buffer', 'get_placeholder_mask', 'get_position_embeddings', 'get_rope_index', 'get_submodule', 'get_video_features', 'gradient_checkpointing_disable', 'gradient_checkpointing_enable', 'half', 'hf_device_map', 'init_weights', 'initialize_weights', 'invert_attention_mask', 'ipu', 'is_backend_compatible', 'is_gradient_checkpointing', 'is_parallelizable', 'language_model', 'load_adapter', 'load_state_dict', 'loss_function', 'loss_type', 'main_input_name', 'model_tags', 'modules', 'mtia', 'name_or_path', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'num_parameters', 'parameters', 'post_init', 'prune_heads', 'push_to_hub', 'register_backward_hook', 'register_buffer', 'register_for_auto_class', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_load_state_dict_pre_hook', 'register_module', 'register_parameter', 'register_state_dict_post_hook', 'register_state_dict_pre_hook', 'requires_grad_', 'reset_memory_hooks_state', 'resize_position_embeddings', 'resize_token_embeddings', 'retrieve_modules_from_names', 'reverse_bettertransformer', 'rope_deltas', 'save_pretrained', 'set_adapter', 'set_attn_implementation', 'set_decoder', 'set_extra_state', 'set_input_embeddings', 'set_output_embeddings', 'set_submodule', 'share_memory', 'smart_apply', 'state_dict', 'supports_gradient_checkpointing', 'supports_pp_plan', 'supports_tp_plan', 'tie_weights', 'to', 'to', 'to_bettertransformer', 'to_empty', 'tp_size', 'train', 'training', 'type', 'visual', 'warn_if_padding_and_no_attention_mask', 'warnings_issued', 'xpu', 'zero_grad'] +ERROR:multimodal_retrieval_local:encode_image: 处理图像时出错: embedding(): argument 'indices' (position 2) must be Tensor, not NoneType +ERROR:multimodal_retrieval_local:add_images: 图像编码失败,返回空数组 +INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index +INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 05:52:46] "POST /api/add_image HTTP/1.1" 200 - +INFO:multimodal_retrieval_local:encode_image: 处理后的输入键: ['pixel_values'] +INFO:multimodal_retrieval_local:encode_image: 运行模型推理 +INFO:multimodal_retrieval_local:Model类型: +INFO:multimodal_retrieval_local:Model属性: ['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_auto_class', '_backward_compatibility_gradient_checkpointing', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_can_compile_fullgraph', '_can_record_outputs', '_can_set_attn_implementation', '_check_and_adjust_attn_implementation', '_checkpoint_conversion_mapping', '_compiled_call_impl', '_convert_head_mask_to_5d', '_copy_lm_head_original_to_resized', '_create_repo', '_dispatch_accelerate_model', '_fix_state_dict_key_on_load', '_fix_state_dict_key_on_save', '_fix_state_dict_keys_on_save', '_flash_attn_2_can_dispatch', '_flash_attn_3_can_dispatch', '_flex_attn_can_dispatch', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_from_config', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_files_timestamps', '_get_key_renaming_mapping', '_get_name', '_get_no_split_modules', '_get_resized_embeddings', '_get_resized_lm_head', '_hf_hook', '_hf_peft_config_loaded', '_hook_rss_memory_post_forward', '_hook_rss_memory_pre_forward', '_init_added_embeddings_weights_with_mean', '_init_added_lm_head_bias_with_mean', '_init_added_lm_head_weights_with_mean', '_init_weights', '_initialize_missing_keys', '_initialize_weights', '_input_embed_layer', '_is_full_backward_hook', '_is_hf_initialized', '_is_stateful', '_keep_in_fp32_modules', '_keep_in_fp32_modules', '_keep_in_fp32_modules_strict', '_keep_in_fp32_modules_strict', '_keys_to_ignore_on_load_missing', '_keys_to_ignore_on_load_unexpected', '_keys_to_ignore_on_save', '_load_from_flax', '_load_from_state_dict', '_load_from_tf', '_load_pretrained_model', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_move_missing_keys_from_meta_to_cpu', '_named_members', '_no_split_modules', '_no_split_modules', '_non_persistent_buffers_set', '_old_forward', '_parameters', '_pp_plan', '_pp_plan', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_parallel', '_resize_token_embeddings', '_save_to_state_dict', '_sdpa_can_dispatch', '_set_default_torch_dtype', '_set_gradient_checkpointing', '_skip_keys_device_placement', '_slow_forward', '_state_dict_hooks', '_state_dict_pre_hooks', '_supports_attention_backend', '_supports_flash_attn', '_supports_flex_attn', '_supports_sdpa', '_tie_encoder_decoder_weights', '_tie_or_clone_weights', '_tied_weights_keys', '_tp_plan', '_tp_size', '_upload_modified_files', '_version', '_wrapped_call_impl', 'active_adapter', 'active_adapters', 'add_adapter', 'add_memory_hooks', 'add_model_tags', 'add_module', 'apply', 'base_model', 'base_model_prefix', 'bfloat16', 'buffers', 'call_super_init', 'can_generate', 'can_record_outputs', 'children', 'compile', 'config', 'config_class', 'cpu', 'create_extended_attention_mask_for_decoder', 'cuda', 'cuda', 'delete_adapter', 'dequantize', 'device', 'disable_adapters', 'disable_input_require_grads', 'double', 'dtype', 'dummy_inputs', 'dump_patches', 'enable_adapters', 'enable_input_require_grads', 'estimate_tokens', 'eval', 'extra_repr', 'float', 'floating_point_ops', 'forward', 'forward', 'framework', 'from_pretrained', 'generation_config', 'get_adapter_state_dict', 'get_buffer', 'get_compiled_call', 'get_correct_attn_implementation', 'get_decoder', 'get_extended_attention_mask', 'get_extra_state', 'get_head_mask', 'get_image_features', 'get_init_context', 'get_input_embeddings', 'get_memory_footprint', 'get_output_embeddings', 'get_parameter', 'get_parameter_or_buffer', 'get_placeholder_mask', 'get_position_embeddings', 'get_rope_index', 'get_submodule', 'get_video_features', 'gradient_checkpointing_disable', 'gradient_checkpointing_enable', 'half', 'hf_device_map', 'init_weights', 'initialize_weights', 'invert_attention_mask', 'ipu', 'is_backend_compatible', 'is_gradient_checkpointing', 'is_parallelizable', 'language_model', 'load_adapter', 'load_state_dict', 'loss_function', 'loss_type', 'main_input_name', 'model_tags', 'modules', 'mtia', 'name_or_path', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'num_parameters', 'parameters', 'post_init', 'prune_heads', 'push_to_hub', 'register_backward_hook', 'register_buffer', 'register_for_auto_class', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_load_state_dict_pre_hook', 'register_module', 'register_parameter', 'register_state_dict_post_hook', 'register_state_dict_pre_hook', 'requires_grad_', 'reset_memory_hooks_state', 'resize_position_embeddings', 'resize_token_embeddings', 'retrieve_modules_from_names', 'reverse_bettertransformer', 'rope_deltas', 'save_pretrained', 'set_adapter', 'set_attn_implementation', 'set_decoder', 'set_extra_state', 'set_input_embeddings', 'set_output_embeddings', 'set_submodule', 'share_memory', 'smart_apply', 'state_dict', 'supports_gradient_checkpointing', 'supports_pp_plan', 'supports_tp_plan', 'tie_weights', 'to', 'to', 'to_bettertransformer', 'to_empty', 'tp_size', 'train', 'training', 'type', 'visual', 'warn_if_padding_and_no_attention_mask', 'warnings_issued', 'xpu', 'zero_grad'] +ERROR:multimodal_retrieval_local:encode_image: 处理图像时出错: embedding(): argument 'indices' (position 2) must be Tensor, not NoneType +ERROR:multimodal_retrieval_local:add_images: 图像编码失败,返回空数组 +INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index +INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 05:52:59] "POST /api/add_image HTTP/1.1" 200 - +INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index +INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 05:53:00] "POST /api/save_index HTTP/1.1" 200 - +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 05:53:01] "GET /api/system_info HTTP/1.1" 200 - diff --git a/baidu_vdb_backend.py b/baidu_vdb_backend.py index f17975e..d293724 100644 --- a/baidu_vdb_backend.py +++ b/baidu_vdb_backend.py @@ -118,30 +118,29 @@ class BaiduVDBBackend: try: logger.info(f"创建文本向量表: {self.text_table_name}") - # 定义字段 - 使用最简单的配置 + # 定义字段 - 移除可能导致问题的复杂配置 fields = [ - Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True), + Field("id", FieldType.STRING, primary_key=True, not_null=True), Field("text_content", FieldType.STRING, not_null=True), Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension) ] - # 定义索引 + # 定义索引 - 简化配置 indexes = [ VectorIndex( index_name="text_vector_idx", index_type=IndexType.HNSW, field="vector", metric_type=MetricType.COSINE, - params=HNSWParams(m=32, efconstruction=200), + params=HNSWParams(m=16, efconstruction=100), auto_build=True ) ] - # 创建表 + # 创建表 - 简化配置 self.text_table = self.db.create_table( table_name=self.text_table_name, - replication=2, # 双副本 - partition=Partition(partition_num=3), # 3个分区 + replication=1, # 单副本 schema=Schema(fields=fields, indexes=indexes) ) @@ -156,30 +155,29 @@ class BaiduVDBBackend: try: logger.info(f"创建图像向量表: {self.image_table_name}") - # 定义字段 - 使用最简单的配置 + # 定义字段 - 移除可能导致问题的复杂配置 fields = [ - Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True), + Field("id", FieldType.STRING, primary_key=True, not_null=True), Field("image_path", FieldType.STRING, not_null=True), Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension) ] - # 定义索引 + # 定义索引 - 简化配置 indexes = [ VectorIndex( index_name="image_vector_idx", index_type=IndexType.HNSW, field="vector", metric_type=MetricType.COSINE, - params=HNSWParams(m=32, efconstruction=200), + params=HNSWParams(m=16, efconstruction=100), auto_build=True ) ] - # 创建表 + # 创建表 - 简化配置 self.image_table = self.db.create_table( table_name=self.image_table_name, - replication=2, # 双副本 - partition=Partition(partition_num=3), # 3个分区 + replication=1, # 单副本 schema=Schema(fields=fields, indexes=indexes) ) diff --git a/faiss_index_local.index b/faiss_index_local.index new file mode 100644 index 0000000..27dba4e Binary files /dev/null and b/faiss_index_local.index differ diff --git a/faiss_index_local_metadata.json b/faiss_index_local_metadata.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/faiss_index_local_metadata.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/faiss_index_test.index b/faiss_index_test.index new file mode 100644 index 0000000..27dba4e Binary files /dev/null and b/faiss_index_test.index differ diff --git a/faiss_index_test_metadata.json b/faiss_index_test_metadata.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/faiss_index_test_metadata.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/faiss_vector_store.py b/faiss_vector_store.py new file mode 100644 index 0000000..cacfcf0 --- /dev/null +++ b/faiss_vector_store.py @@ -0,0 +1,147 @@ +import os +import json +import numpy as np +import faiss +from typing import List, Dict, Any, Optional, Tuple +import logging + +class FaissVectorStore: + def __init__(self, index_path: str = "faiss_index", dimension: int = 3584): + """ + 初始化FAISS向量存储 + + 参数: + index_path: 索引文件路径 + dimension: 向量维度 + """ + self.index_path = index_path + self.dimension = dimension + self.index = None + self.metadata = {} + self.metadata_path = f"{index_path}_metadata.json" + + # 加载现有索引或创建新索引 + self._load_or_create_index() + + def _load_or_create_index(self): + """加载现有索引或创建新索引""" + if os.path.exists(f"{self.index_path}.index"): + logging.info(f"加载现有索引: {self.index_path}") + self.index = faiss.read_index(f"{self.index_path}.index") + self._load_metadata() + else: + logging.info(f"创建新索引,维度: {self.dimension}") + self.index = faiss.IndexFlatL2(self.dimension) # 使用L2距离 + + def _load_metadata(self): + """加载元数据""" + if os.path.exists(self.metadata_path): + with open(self.metadata_path, 'r', encoding='utf-8') as f: + self.metadata = json.load(f) + + def _save_metadata(self): + """保存元数据到文件""" + with open(self.metadata_path, 'w', encoding='utf-8') as f: + json.dump(self.metadata, f, ensure_ascii=False, indent=2) + + def save_index(self): + """保存索引和元数据""" + if self.index is not None: + faiss.write_index(self.index, f"{self.index_path}.index") + self._save_metadata() + logging.info(f"索引已保存到 {self.index_path}.index") + + def add_vectors( + self, + vectors: np.ndarray, + metadatas: List[Dict[str, Any]] + ) -> List[str]: + """ + 添加向量和元数据 + + 参数: + vectors: 向量数组 + metadatas: 对应的元数据列表 + + 返回: + 添加的向量ID列表 + """ + if len(vectors) != len(metadatas): + raise ValueError("vectors和metadatas长度必须相同") + + start_id = len(self.metadata) + ids = list(range(start_id, start_id + len(vectors))) + + # 添加向量到索引 + self.index.add(vectors.astype('float32')) + + # 保存元数据 + for idx, vector_id in enumerate(ids): + self.metadata[str(vector_id)] = metadatas[idx] + + # 保存索引和元数据 + self.save_index() + + return [str(id) for id in ids] + + def search( + self, + query_vector: np.ndarray, + k: int = 5 + ) -> Tuple[List[Dict[str, Any]], List[float]]: + """ + 相似性搜索 + + 参数: + query_vector: 查询向量 + k: 返回结果数量 + + 返回: + (结果列表, 距离列表) + """ + if self.index is None: + return [], [] + + # 确保输入是2D数组 + if len(query_vector.shape) == 1: + query_vector = query_vector.reshape(1, -1) + + # 执行搜索 + distances, indices = self.index.search(query_vector.astype('float32'), k) + + # 处理结果 + results = [] + for i in range(len(indices[0])): + idx = indices[0][i] + if idx < 0: # FAISS可能返回-1表示无效索引 + continue + + vector_id = str(idx) + if vector_id in self.metadata: + result = self.metadata[vector_id].copy() + result['distance'] = float(distances[0][i]) + results.append(result) + + return results, distances[0].tolist() + + def get_vector_count(self) -> int: + """获取向量数量""" + return self.index.ntotal if self.index is not None else 0 + + def delete_vectors(self, vector_ids: List[str]) -> bool: + """ + 删除指定ID的向量 + + 注意: FAISS不支持直接删除向量,这里实现为逻辑删除 + """ + deleted_count = 0 + for vector_id in vector_ids: + if vector_id in self.metadata: + del self.metadata[vector_id] + deleted_count += 1 + + if deleted_count > 0: + self._save_metadata() + logging.warning("FAISS不支持直接删除向量,已从元数据中移除,但索引中仍保留") + + return deleted_count > 0 diff --git a/local_faiss_index.index b/local_faiss_index.index new file mode 100644 index 0000000..27dba4e Binary files /dev/null and b/local_faiss_index.index differ diff --git a/local_faiss_index_metadata.json b/local_faiss_index_metadata.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/local_faiss_index_metadata.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/local_file_handler.py b/local_file_handler.py new file mode 100644 index 0000000..00039c2 --- /dev/null +++ b/local_file_handler.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +本地文件处理器 +简化版的文件处理器,不依赖外部服务 +""" + +import os +import io +import tempfile +import logging +from contextlib import contextmanager +from typing import Dict, List, Optional, Any, Union, BinaryIO +from pathlib import Path + +logger = logging.getLogger(__name__) + +class LocalFileHandler: + """本地文件处理器""" + + # 小文件阈值 (5MB) + SMALL_FILE_THRESHOLD = 5 * 1024 * 1024 + + # 支持的图像格式 + SUPPORTED_IMAGE_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'} + + def __init__(self, temp_dir: str = None): + """ + 初始化本地文件处理器 + + Args: + temp_dir: 临时文件目录 + """ + self.temp_dir = temp_dir or tempfile.gettempdir() + self.temp_files = set() # 跟踪临时文件 + + # 确保临时目录存在 + os.makedirs(self.temp_dir, exist_ok=True) + + @contextmanager + def temp_file_context(self, content: bytes = None, suffix: str = None, delete_on_exit: bool = True): + """临时文件上下文管理器,确保自动清理""" + temp_fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=self.temp_dir) + self.temp_files.add(temp_path) + + try: + os.close(temp_fd) # 关闭文件描述符 + + # 如果提供了内容,写入文件 + if content is not None: + with open(temp_path, 'wb') as f: + f.write(content) + + yield temp_path + finally: + if delete_on_exit and os.path.exists(temp_path): + try: + os.unlink(temp_path) + self.temp_files.discard(temp_path) + logger.debug(f"🗑️ 临时文件已清理: {temp_path}") + except Exception as e: + logger.warning(f"⚠️ 临时文件清理失败: {temp_path}, {e}") + + def cleanup_all_temp_files(self): + """清理所有跟踪的临时文件""" + for temp_path in list(self.temp_files): + if os.path.exists(temp_path): + try: + os.unlink(temp_path) + logger.debug(f"🗑️ 清理临时文件: {temp_path}") + except Exception as e: + logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}") + self.temp_files.clear() + + def get_file_size(self, file_obj) -> int: + """获取文件大小""" + if hasattr(file_obj, 'content_length') and file_obj.content_length: + return file_obj.content_length + + # 通过读取内容获取大小 + current_pos = file_obj.tell() + file_obj.seek(0, 2) # 移动到文件末尾 + size = file_obj.tell() + file_obj.seek(current_pos) # 恢复原位置 + return size + + def is_small_file(self, file_obj) -> bool: + """判断是否为小文件""" + return self.get_file_size(file_obj) <= self.SMALL_FILE_THRESHOLD + + def get_temp_file_for_model(self, file_obj, filename: str) -> Optional[str]: + """为模型处理获取临时文件路径(确保文件存在于本地)""" + try: + ext = os.path.splitext(filename)[1].lower() + + # 创建临时文件(不自动删除,供模型使用) + temp_fd, temp_path = tempfile.mkstemp(suffix=ext, dir=self.temp_dir) + self.temp_files.add(temp_path) + + try: + # 写入文件内容 + file_obj.seek(0) + with os.fdopen(temp_fd, 'wb') as temp_file: + temp_file.write(file_obj.read()) + + logger.debug(f"📁 为模型创建临时文件: {temp_path}") + return temp_path + + except Exception as e: + os.close(temp_fd) + raise e + + except Exception as e: + logger.error(f"❌ 为模型创建临时文件失败: {filename}, {e}") + return None + + def cleanup_temp_file(self, temp_path: str): + """清理指定的临时文件""" + if temp_path and os.path.exists(temp_path): + try: + os.unlink(temp_path) + self.temp_files.discard(temp_path) + logger.debug(f"🗑️ 清理临时文件: {temp_path}") + except Exception as e: + logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}") + +# 全局实例 +file_handler = None + +def get_file_handler(temp_dir: str = None) -> LocalFileHandler: + """获取文件处理器实例""" + global file_handler + if file_handler is None: + file_handler = LocalFileHandler(temp_dir=temp_dir) + return file_handler diff --git a/model_download_guide.md b/model_download_guide.md new file mode 100644 index 0000000..f2a2f20 --- /dev/null +++ b/model_download_guide.md @@ -0,0 +1,108 @@ +# 多模态模型下载指南 + +## 下载 OpenSearch-AI/Ops-MM-embedding-v1-7B 模型 + +### 方法1:使用 git-lfs + +```bash +# 安装 git-lfs +apt-get install git-lfs +# 或 +curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash +apt-get install git-lfs + +# 初始化 git-lfs +git lfs install + +# 克隆模型仓库 +mkdir -p ~/models +git clone https://huggingface.co/OpenSearch-AI/Ops-MM-embedding-v1-7B ~/models/Ops-MM-embedding-v1-7B +``` + +### 方法2:使用 huggingface-cli + +```bash +# 安装 huggingface-hub +pip install huggingface-hub + +# 下载模型 +mkdir -p ~/models +huggingface-cli download OpenSearch-AI/Ops-MM-embedding-v1-7B --local-dir ~/models/Ops-MM-embedding-v1-7B +``` + +### 方法3:手动下载关键文件 + +如果上述方法不可行,可以手动下载以下关键文件: + +1. 访问 https://huggingface.co/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main +2. 下载以下文件: + - `config.json` + - `pytorch_model.bin` (或分片文件 `pytorch_model-00001-of-00002.bin` 等) + - `tokenizer.json` + - `tokenizer_config.json` + - `special_tokens_map.json` + - `vocab.txt` + +## 下载替代轻量级模型 + +如果主模型太大,可以下载这些较小的替代模型: + +### CLIP 模型 + +```bash +mkdir -p ~/models/clip-ViT-B-32 +huggingface-cli download openai/clip-vit-base-patch32 --local-dir ~/models/clip-ViT-B-32 +``` + +### 多语言CLIP模型 + +```bash +mkdir -p ~/models/clip-multilingual +huggingface-cli download sentence-transformers/clip-ViT-B-32-multilingual-v1 --local-dir ~/models/clip-multilingual +``` + +## 传输模型文件 + +下载完成后,使用以下方法将模型传输到目标服务器: + +### 使用 scp + +```bash +# 从当前机器传输到目标服务器 +scp -r ~/models/Ops-MM-embedding-v1-7B user@target-server:/root/models/ +``` + +### 使用压缩文件 + +```bash +# 压缩 +tar -czvf model.tar.gz ~/models/Ops-MM-embedding-v1-7B + +# 传输压缩文件 +scp model.tar.gz user@target-server:/root/ + +# 在目标服务器上解压 +ssh user@target-server +mkdir -p /root/models +tar -xzvf /root/model.tar.gz -C /root/models +``` + +## 验证模型文件 + +模型下载完成后,目录结构应类似于: + +``` +/root/models/Ops-MM-embedding-v1-7B/ +├── config.json +├── pytorch_model.bin (或分片文件) +├── tokenizer.json +├── tokenizer_config.json +├── special_tokens_map.json +└── vocab.txt +``` + +使用以下命令验证文件完整性: + +```bash +ls -la /root/models/Ops-MM-embedding-v1-7B/ +``` diff --git a/multimodal_retrieval_faiss.py b/multimodal_retrieval_faiss.py new file mode 100644 index 0000000..f5949bc --- /dev/null +++ b/multimodal_retrieval_faiss.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +基于FAISS的多模态检索系统 +支持文搜文、文搜图、图搜文、图搜图四种检索模式 +""" + +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel +import numpy as np +from PIL import Image +from transformers import AutoModel, AutoProcessor, AutoTokenizer +from typing import List, Union, Tuple, Dict, Any, Optional +import os +import json +from pathlib import Path +import logging +import gc +from concurrent.futures import ThreadPoolExecutor, as_completed +import threading + +from faiss_vector_store import FaissVectorStore + +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class MultimodalRetrievalFAISS: + """基于FAISS的多模态检索系统""" + + def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B", + use_all_gpus: bool = True, gpu_ids: List[int] = None, + min_memory_gb: int = 12, index_path: str = "faiss_index"): + """ + 初始化多模态检索系统 + + Args: + model_name: 模型名称 + use_all_gpus: 是否使用所有可用GPU + gpu_ids: 指定使用的GPU ID列表 + min_memory_gb: 最小可用内存(GB) + index_path: FAISS索引文件路径 + """ + self.model_name = model_name + self.index_path = index_path + + # 设置GPU设备 + self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb) + + # 清理GPU内存 + self._clear_all_gpu_memory() + + # 加载模型和处理器 + self._load_model_and_processor() + + # 初始化FAISS向量存储 + self.vector_store = FaissVectorStore( + index_path=index_path, + dimension=3584 # OpenSearch-AI/Ops-MM-embedding-v1-7B的向量维度 + ) + + logger.info(f"多模态检索系统初始化完成,使用模型: {model_name}") + logger.info(f"向量存储路径: {index_path}") + + def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int): + """设置GPU设备""" + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.use_gpu = self.device.type == "cuda" + + if self.use_gpu: + self.available_gpus = self._get_available_gpus(min_memory_gb) + + if not self.available_gpus: + logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU") + self.device = torch.device("cpu") + self.use_gpu = False + else: + if gpu_ids: + self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus] + if not self.gpu_ids: + logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}") + self.gpu_ids = self.available_gpus + elif use_all_gpus: + self.gpu_ids = self.available_gpus + else: + self.gpu_ids = [self.available_gpus[0]] + + logger.info(f"使用GPU: {self.gpu_ids}") + self.device = torch.device(f"cuda:{self.gpu_ids[0]}") + + def _get_available_gpus(self, min_memory_gb: int) -> List[int]: + """获取可用的GPU列表""" + available_gpus = [] + for i in range(torch.cuda.device_count()): + total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB + if total_mem >= min_memory_gb: + available_gpus.append(i) + return available_gpus + + def _clear_all_gpu_memory(self): + """清理GPU内存""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + def _load_model_and_processor(self): + """加载模型和处理器""" + logger.info(f"加载模型和处理器: {self.model_name}") + + # 加载tokenizer和processor + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.processor = AutoProcessor.from_pretrained(self.model_name) + + # 加载模型 + self.model = AutoModel.from_pretrained( + self.model_name, + torch_dtype=torch.float16 if self.use_gpu else torch.float32, + device_map="auto" if len(self.gpu_ids) > 1 else None + ) + + # 如果使用多GPU,包装模型 + if len(self.gpu_ids) > 1: + self.model = DataParallel(self.model, device_ids=self.gpu_ids) + + self.model.eval() + self.model.to(self.device) + + logger.info("模型和处理器加载完成") + + def encode_text(self, text: Union[str, List[str]]) -> np.ndarray: + """编码文本为向量""" + if isinstance(text, str): + text = [text] + + inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + # 获取[CLS]标记的隐藏状态作为句子表示 + text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() + + # 归一化向量 + text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True) + return text_embeddings[0] if len(text) == 1 else text_embeddings + + def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray: + """编码图像为向量""" + if isinstance(image, Image.Image): + image = [image] + + inputs = self.processor(images=image, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model.vision_model(**inputs) + # 获取[CLS]标记的隐藏状态作为图像表示 + image_embeddings = outputs.pooler_output.cpu().numpy() + + # 归一化向量 + image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True) + return image_embeddings[0] if len(image) == 1 else image_embeddings + + def add_texts( + self, + texts: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None + ) -> List[str]: + """ + 添加文本到检索系统 + + Args: + texts: 文本列表 + metadatas: 元数据列表,每个元素是一个字典 + + Returns: + 添加的文本ID列表 + """ + if not texts: + return [] + + if metadatas is None: + metadatas = [{} for _ in range(len(texts))] + + if len(texts) != len(metadatas): + raise ValueError("texts和metadatas长度必须相同") + + # 编码文本 + text_embeddings = self.encode_text(texts) + + # 准备元数据 + for i, text in enumerate(texts): + metadatas[i].update({ + "text": text, + "type": "text" + }) + + # 添加到向量存储 + vector_ids = self.vector_store.add_vectors(text_embeddings, metadatas) + + logger.info(f"成功添加{len(vector_ids)}条文本到检索系统") + return vector_ids + + def add_images( + self, + images: List[Image.Image], + metadatas: Optional[List[Dict[str, Any]]] = None + ) -> List[str]: + """ + 添加图像到检索系统 + + Args: + images: PIL图像列表 + metadatas: 元数据列表,每个元素是一个字典 + + Returns: + 添加的图像ID列表 + """ + if not images: + return [] + + if metadatas is None: + metadatas = [{} for _ in range(len(images))] + + if len(images) != len(metadatas): + raise ValueError("images和metadatas长度必须相同") + + # 编码图像 + image_embeddings = self.encode_image(images) + + # 准备元数据 + for i, image in enumerate(images): + metadatas[i].update({ + "type": "image", + "width": image.width, + "height": image.height + }) + + # 添加到向量存储 + vector_ids = self.vector_store.add_vectors(image_embeddings, metadatas) + + logger.info(f"成功添加{len(vector_ids)}张图像到检索系统") + return vector_ids + + def search_by_text( + self, + query: str, + k: int = 5, + filter_condition: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: + """ + 文本搜索 + + Args: + query: 查询文本 + k: 返回结果数量 + filter_condition: 过滤条件 + + Returns: + 搜索结果列表,每个元素包含相似项和分数 + """ + # 编码查询文本 + query_embedding = self.encode_text(query) + + # 执行搜索 + results, distances = self.vector_store.search(query_embedding, k) + + # 处理结果 + search_results = [] + for i, (result, distance) in enumerate(zip(results, distances)): + result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数 + search_results.append(result) + + return search_results + + def search_by_image( + self, + image: Image.Image, + k: int = 5, + filter_condition: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: + """ + 图像搜索 + + Args: + image: 查询图像 + k: 返回结果数量 + filter_condition: 过滤条件 + + Returns: + 搜索结果列表,每个元素包含相似项和分数 + """ + # 编码查询图像 + query_embedding = self.encode_image(image) + + # 执行搜索 + results, distances = self.vector_store.search(query_embedding, k) + + # 处理结果 + search_results = [] + for i, (result, distance) in enumerate(zip(results, distances)): + result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数 + search_results.append(result) + + return search_results + + def get_vector_count(self) -> int: + """获取向量数量""" + return self.vector_store.get_vector_count() + + def save_index(self): + """保存索引""" + self.vector_store.save_index() + logger.info("索引已保存") + + def __del__(self): + """析构函数,确保资源被正确释放""" + if hasattr(self, 'model'): + del self.model + self._clear_all_gpu_memory() + if hasattr(self, 'vector_store'): + self.save_index() + + +def test_faiss_system(): + """测试FAISS多模态检索系统""" + import time + from PIL import Image + import numpy as np + + # 初始化检索系统 + print("初始化多模态检索系统...") + retrieval = MultimodalRetrievalFAISS( + model_name="OpenSearch-AI/Ops-MM-embedding-v1-7B", + use_all_gpus=True, + index_path="faiss_index_test" + ) + + # 测试文本 + texts = [ + "一只可爱的橘色猫咪在沙发上睡觉", + "城市夜景中的高楼大厦和车流", + "阳光明媚的海滩上,人们在冲浪和晒太阳", + "美味的意大利面配红酒和沙拉", + "雪山上滑雪的运动员" + ] + + # 添加文本 + print("\n添加文本到检索系统...") + text_ids = retrieval.add_texts(texts) + print(f"添加了{len(text_ids)}条文本") + + # 测试文本搜索 + print("\n测试文本搜索...") + query_text = "一只猫在睡觉" + print(f"查询: {query_text}") + results = retrieval.search_by_text(query_text, k=2) + for i, result in enumerate(results): + print(f"结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})") + + # 测试保存和加载 + print("\n保存索引...") + retrieval.save_index() + + print("\n测试完成!") + + +if __name__ == "__main__": + test_faiss_system() diff --git a/multimodal_retrieval_local.py b/multimodal_retrieval_local.py new file mode 100644 index 0000000..8ff9208 --- /dev/null +++ b/multimodal_retrieval_local.py @@ -0,0 +1,607 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +使用本地模型的多模态检索系统 +支持文搜文、文搜图、图搜文、图搜图四种检索模式 +""" + +import torch +import numpy as np +from PIL import Image +from transformers import AutoModel, AutoProcessor, AutoTokenizer +from typing import List, Union, Tuple, Dict, Any, Optional +import os +import json +from pathlib import Path +import logging +import gc +import faiss +import time + +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# 设置离线模式 +os.environ['TRANSFORMERS_OFFLINE'] = '1' + +class MultimodalRetrievalLocal: + """使用本地模型的多模态检索系统""" + + def __init__(self, + model_path: str = "/root/models/Ops-MM-embedding-v1-7B", + use_all_gpus: bool = True, + gpu_ids: List[int] = None, + min_memory_gb: int = 12, + index_path: str = "local_faiss_index"): + """ + 初始化多模态检索系统 + + Args: + model_path: 本地模型路径 + use_all_gpus: 是否使用所有可用GPU + gpu_ids: 指定使用的GPU ID列表 + min_memory_gb: 最小可用内存(GB) + index_path: FAISS索引文件路径 + """ + self.model_path = model_path + self.index_path = index_path + + # 检查模型路径 + if not os.path.exists(model_path): + logger.error(f"模型路径不存在: {model_path}") + logger.info("请先下载模型到指定路径") + raise FileNotFoundError(f"模型路径不存在: {model_path}") + + # 设置GPU设备 + self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb) + + # 清理GPU内存 + self._clear_all_gpu_memory() + + # 加载模型和处理器 + self._load_model_and_processor() + + # 初始化FAISS索引 + self._init_index() + + logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}") + logger.info(f"向量存储路径: {index_path}") + + def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int): + """设置GPU设备""" + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.use_gpu = self.device.type == "cuda" + + if self.use_gpu: + self.available_gpus = self._get_available_gpus(min_memory_gb) + + if not self.available_gpus: + logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU") + self.device = torch.device("cpu") + self.use_gpu = False + else: + if gpu_ids: + self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus] + if not self.gpu_ids: + logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}") + self.gpu_ids = self.available_gpus + elif use_all_gpus: + self.gpu_ids = self.available_gpus + else: + self.gpu_ids = [self.available_gpus[0]] + + logger.info(f"使用GPU: {self.gpu_ids}") + self.device = torch.device(f"cuda:{self.gpu_ids[0]}") + else: + logger.warning("没有可用的GPU,将使用CPU") + self.gpu_ids = [] + + def _get_available_gpus(self, min_memory_gb: int) -> List[int]: + """获取可用的GPU列表""" + available_gpus = [] + for i in range(torch.cuda.device_count()): + total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB + if total_mem >= min_memory_gb: + available_gpus.append(i) + return available_gpus + + def _clear_all_gpu_memory(self): + """清理GPU内存""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + def _load_model_and_processor(self): + """加载模型和处理器""" + logger.info(f"加载本地模型和处理器: {self.model_path}") + + try: + # 加载模型和处理器 + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) + self.processor = AutoProcessor.from_pretrained(self.model_path) + + # 输出处理器信息 + logger.info(f"Processor类型: {type(self.processor)}") + logger.info(f"Processor方法: {dir(self.processor)}") + + # 检查是否有图像处理器 + if hasattr(self.processor, 'image_processor'): + logger.info(f"Image processor类型: {type(self.processor.image_processor)}") + logger.info(f"Image processor方法: {dir(self.processor.image_processor)}") + + # 加载模型 + self.model = AutoModel.from_pretrained( + self.model_path, + torch_dtype=torch.float16 if self.use_gpu else torch.float32, + device_map="auto" if len(self.gpu_ids) > 1 else None + ) + + if len(self.gpu_ids) == 1: + self.model.to(self.device) + + self.model.eval() + + # 获取向量维度 + self.vector_dim = self.model.config.hidden_size + logger.info(f"向量维度: {self.vector_dim}") + + logger.info("模型和处理器加载成功") + + except Exception as e: + logger.error(f"模型加载失败: {str(e)}") + raise RuntimeError(f"模型加载失败: {str(e)}") + + def _init_index(self): + """初始化FAISS索引""" + index_file = f"{self.index_path}.index" + if os.path.exists(index_file): + logger.info(f"加载现有索引: {index_file}") + try: + self.index = faiss.read_index(index_file) + logger.info(f"索引加载成功,包含{self.index.ntotal}个向量") + except Exception as e: + logger.error(f"索引加载失败: {str(e)}") + logger.info("创建新索引...") + self.index = faiss.IndexFlatL2(self.vector_dim) + else: + logger.info(f"创建新索引,维度: {self.vector_dim}") + self.index = faiss.IndexFlatL2(self.vector_dim) + + # 加载元数据 + self.metadata = {} + metadata_file = f"{self.index_path}_metadata.json" + if os.path.exists(metadata_file): + try: + with open(metadata_file, 'r', encoding='utf-8') as f: + self.metadata = json.load(f) + logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录") + except Exception as e: + logger.error(f"元数据加载失败: {str(e)}") + + def encode_text(self, text: Union[str, List[str]]) -> np.ndarray: + """编码文本为向量""" + if isinstance(text, str): + text = [text] + + inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + # 获取[CLS]标记的隐藏状态作为句子表示 + text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() + + # 归一化向量 + text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True) + return text_embeddings[0] if len(text) == 1 else text_embeddings + + def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray: + """编码图像为向量""" + try: + logger.info(f"encode_image: 开始编码图像,类型: {type(image)}") + + if isinstance(image, Image.Image): + logger.info(f"encode_image: 单个图像,大小: {image.size}") + image = [image] + else: + logger.info(f"encode_image: 图像列表,长度: {len(image)}") + + # 检查图像是否为空 + if not image or len(image) == 0: + logger.error("encode_image: 图像列表为空") + # 返回一个空的二维数组 + return np.zeros((0, self.vector_dim)) + + # 检查图像是否有效 + for i, img in enumerate(image): + if not isinstance(img, Image.Image): + logger.error(f"encode_image: 第{i}个元素不是有效的PIL图像,类型: {type(img)}") + # 返回一个空的二维数组 + return np.zeros((0, self.vector_dim)) + + logger.info("encode_image: 处理图像输入") + + # 检查图像格式 + for i, img in enumerate(image): + logger.info(f"encode_image: 图像 {i} 格式: {img.format}, 模式: {img.mode}, 大小: {img.size}") + # 转换为RGB模式,如果不是 + if img.mode != 'RGB': + logger.info(f"encode_image: 将图像 {i} 从 {img.mode} 转换为 RGB") + image[i] = img.convert('RGB') + + try: + # 直接使用image_processor处理图像 + if hasattr(self.processor, 'image_processor'): + logger.info("encode_image: 使用image_processor处理图像") + pixel_values = self.processor.image_processor(images=image, return_tensors="pt").pixel_values + inputs = {"pixel_values": pixel_values} + else: + logger.info("encode_image: 使用processor处理图像") + inputs = self.processor(images=image, return_tensors="pt") + + if not inputs or len(inputs) == 0: + logger.error("encode_image: processor返回了空的输入") + return np.zeros((0, self.vector_dim)) + + logger.info(f"encode_image: 处理后的输入键: {list(inputs.keys())}") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + logger.info("encode_image: 运行模型推理") + logger.info(f"Model类型: {type(self.model)}") + logger.info(f"Model属性: {dir(self.model)}") + + # 检查模型结构 + try: + logger.info(f"Model配置: {self.model.config}") + logger.info(f"Model配置属性: {dir(self.model.config)}") + else: + visual_outputs = self.model.visual(**inputs) + + if hasattr(visual_outputs, 'pooler_output'): + image_embeddings = visual_outputs.pooler_output.cpu().numpy() + elif hasattr(visual_outputs, 'last_hidden_state'): + image_embeddings = visual_outputs.last_hidden_state[:, 0, :].cpu().numpy() + else: + logger.error("encode_image: 无法从视觉模型输出中获取图像向量") + raise ValueError("无法从视觉模型输出中获取图像向量") + else: + # 尝试直接使用模型进行推理 + logger.info("encode_image: 尝试直接使用模型进行推理") + with torch.no_grad(): + # 使用空文本输入,只提供图像 + if 'pixel_values' in inputs: + outputs = self.model(pixel_values=inputs['pixel_values'], input_ids=None) + else: + outputs = self.model(**inputs, input_ids=None) + + # 尝试从输出中获取图像向量 + if hasattr(outputs, 'image_embeds'): + image_embeddings = outputs.image_embeds.cpu().numpy() + elif hasattr(outputs, 'vision_model_output') and hasattr(outputs.vision_model_output, 'pooler_output'): + image_embeddings = outputs.vision_model_output.pooler_output.cpu().numpy() + elif hasattr(outputs, 'pooler_output'): + image_embeddings = outputs.pooler_output.cpu().numpy() + elif hasattr(outputs, 'last_hidden_state'): + image_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() + else: + logger.error("encode_image: 无法从模型输出中获取图像向量") + raise ValueError("无法从模型输出中获取图像向量") + except Exception as e: + logger.error(f"encode_image: 处理图像时出错: {str(e)}") + raise e + return np.zeros((0, self.vector_dim)) + + # 归一化向量 + image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True) + + # 始终返回二维数组,即使只有一个图像 + if len(image) == 1: + result = np.array([image_embeddings[0]]) + logger.info(f"encode_image: 返回单个图像向量,形状: {result.shape}") + return result + else: + logger.info(f"encode_image: 返回多个图像向量,形状: {image_embeddings.shape}") + return image_embeddings + + except Exception as e: + logger.error(f"encode_image: 异常: {str(e)}") + # 返回一个空的二维数组 + return np.zeros((0, self.vector_dim)) + + def add_texts( + self, + texts: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None + ) -> List[str]: + """ + 添加文本到检索系统 + + Args: + texts: 文本列表 + metadatas: 元数据列表,每个元素是一个字典 + + Returns: + 添加的文本ID列表 + """ + if not texts: + return [] + + if metadatas is None: + metadatas = [{} for _ in range(len(texts))] + + if len(texts) != len(metadatas): + raise ValueError("texts和metadatas长度必须相同") + + # 编码文本 + text_embeddings = self.encode_text(texts) + + # 准备元数据 + start_id = self.index.ntotal + ids = list(range(start_id, start_id + len(texts))) + + # 添加到索引 + self.index.add(np.array(text_embeddings).astype('float32')) + + # 保存元数据 + for i, id in enumerate(ids): + self.metadata[str(id)] = { + "text": texts[i], + "type": "text", + **metadatas[i] + } + + logger.info(f"成功添加{len(ids)}条文本到检索系统") + return [str(id) for id in ids] + + def add_images( + self, + images: List[Image.Image], + metadatas: Optional[List[Dict[str, Any]]] = None, + image_paths: Optional[List[str]] = None + ) -> List[str]: + """ + 添加图像到检索系统 + + Args: + images: PIL图像列表 + metadatas: 元数据列表,每个元素是一个字典 + image_paths: 图像路径列表,用于保存到元数据 + + Returns: + 添加的图像ID列表 + """ + try: + logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}") + + # 检查图像列表 + if not images or len(images) == 0: + logger.warning("add_images: 图像列表为空") + return [] + + # 准备元数据 + if metadatas is None: + logger.info("add_images: 创建默认元数据") + metadatas = [{} for _ in range(len(images))] + + # 检查长度一致性 + if len(images) != len(metadatas): + logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}") + raise ValueError("images和metadatas长度必须相同") + + # 编码图像 + logger.info("add_images: 编码图像") + image_embeddings = self.encode_image(images) + + # 检查编码结果 + if image_embeddings.shape[0] == 0: + logger.error("add_images: 图像编码失败,返回空数组") + return [] + + # 准备元数据 + start_id = self.index.ntotal + ids = list(range(start_id, start_id + len(images))) + logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}") + + # 添加到索引 + logger.info(f"add_images: 添加向量到FAISS索引,形状: {image_embeddings.shape}") + self.index.add(np.array(image_embeddings).astype('float32')) + + # 保存元数据 + for i, id in enumerate(ids): + try: + metadata = { + "type": "image", + "width": images[i].width, + "height": images[i].height, + **metadatas[i] + } + + if image_paths and i < len(image_paths): + metadata["path"] = image_paths[i] + + self.metadata[str(id)] = metadata + logger.debug(f"add_images: 保存元数据成功 - ID: {id}") + except Exception as e: + logger.error(f"add_images: 保存元数据失败 - ID: {id}, 错误: {str(e)}") + + logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统") + return [str(id) for id in ids] + + except Exception as e: + logger.error(f"add_images: 添加图像异常: {str(e)}") + return [] + + def search_by_text( + self, + query: str, + k: int = 5, + filter_type: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + 文本搜索 + + Args: + query: 查询文本 + k: 返回结果数量 + filter_type: 过滤类型,可选值: "text", "image", None(不过滤) + + Returns: + 搜索结果列表,每个元素包含相似项和分数 + """ + # 编码查询文本 + query_embedding = self.encode_text(query) + + # 执行搜索 + return self._search(query_embedding, k, filter_type) + + def search_by_image( + self, + image: Image.Image, + k: int = 5, + filter_type: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + 图像搜索 + + Args: + image: 查询图像 + k: 返回结果数量 + filter_type: 过滤类型,可选值: "text", "image", None(不过滤) + + Returns: + 搜索结果列表,每个元素包含相似项和分数 + """ + # 编码查询图像 + query_embedding = self.encode_image(image) + + # 执行搜索 + return self._search(query_embedding, k, filter_type) + + def _search( + self, + query_embedding: np.ndarray, + k: int = 5, + filter_type: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + 执行搜索 + + Args: + query_embedding: 查询向量 + k: 返回结果数量 + filter_type: 过滤类型,可选值: "text", "image", None(不过滤) + + Returns: + 搜索结果列表 + """ + if self.index.ntotal == 0: + return [] + + # 确保查询向量是2D数组 + if len(query_embedding.shape) == 1: + query_embedding = query_embedding.reshape(1, -1) + + # 执行搜索,获取更多结果以便过滤 + actual_k = k * 3 if filter_type else k + actual_k = min(actual_k, self.index.ntotal) + distances, indices = self.index.search(query_embedding.astype('float32'), actual_k) + + # 处理结果 + results = [] + for i in range(len(indices[0])): + idx = indices[0][i] + if idx < 0: # FAISS可能返回-1表示无效索引 + continue + + vector_id = str(idx) + if vector_id in self.metadata: + item = self.metadata[vector_id] + + # 如果指定了过滤类型,则只返回该类型的结果 + if filter_type and item.get("type") != filter_type: + continue + + # 添加距离和分数 + result = item.copy() + result["distance"] = float(distances[0][i]) + result["score"] = float(1.0 / (1.0 + distances[0][i])) + results.append(result) + + # 如果已经收集了足够的结果,则停止 + if len(results) >= k: + break + + return results + + def save_index(self): + """保存索引和元数据""" + # 保存索引 + index_file = f"{self.index_path}.index" + try: + faiss.write_index(self.index, index_file) + logger.info(f"索引保存成功: {index_file}") + except Exception as e: + logger.error(f"索引保存失败: {str(e)}") + + # 保存元数据 + metadata_file = f"{self.index_path}_metadata.json" + try: + with open(metadata_file, 'w', encoding='utf-8') as f: + json.dump(self.metadata, f, ensure_ascii=False, indent=2) + logger.info(f"元数据保存成功: {metadata_file}") + except Exception as e: + logger.error(f"元数据保存失败: {str(e)}") + + def get_stats(self) -> Dict[str, Any]: + """获取检索系统统计信息""" + text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text") + image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image") + + return { + "total_vectors": self.index.ntotal, + "text_count": text_count, + "image_count": image_count, + "vector_dimension": self.vector_dim, + "index_path": self.index_path, + "model_path": self.model_path + } + + def clear_index(self): + """清空索引""" + logger.info(f"清空索引: {self.index_path}") + + # 重新创建索引 + self.index = faiss.IndexFlatL2(self.vector_dim) + + # 清空元数据 + self.metadata = {} + + # 保存空索引 + self.save_index() + + logger.info(f"索引已清空: {self.index_path}") + return True + + def list_items(self) -> List[Dict[str, Any]]: + """列出所有索引项""" + items = [] + + for item_id, metadata in self.metadata.items(): + item = metadata.copy() + item['id'] = item_id + items.append(item) + + return items + + def __del__(self): + """析构函数,确保资源被正确释放并自动保存索引""" + try: + if hasattr(self, 'model'): + del self.model + self._clear_all_gpu_memory() + if hasattr(self, 'index') and self.index is not None: + logger.info("系统关闭前自动保存索引") + self.save_index() + except Exception as e: + logger.error(f"析构时保存索引失败: {str(e)}") diff --git a/multimodal_retrieval_vdb.py b/multimodal_retrieval_vdb.py index 0bbbb2b..e3996a0 100644 --- a/multimodal_retrieval_vdb.py +++ b/multimodal_retrieval_vdb.py @@ -60,7 +60,14 @@ class MultimodalRetrievalVDB: "database_name": "multimodal_retrieval" } - self.vdb = BaiduVDBBackend(**vdb_config) + try: + self.vdb = BaiduVDBBackend(**vdb_config) + logger.info("✅ VDB后端初始化成功") + except Exception as e: + logger.error(f"❌ VDB后端初始化失败: {e}") + # 创建一个模拟的VDB后端,避免系统完全崩溃 + self.vdb = None + logger.warning("⚠️ 系统将在无VDB模式下运行,数据将不会持久化") logger.info("多模态检索系统初始化完成") @@ -102,42 +109,102 @@ class MultimodalRetrievalVDB: # 清理GPU内存 self._clear_gpu_memory() - # 加载模型 - if self.num_gpus > 1: - # 多GPU加载 - max_memory = {i: "18GiB" for i in self.device_ids} + # 设置离线模式环境变量 + os.environ['TRANSFORMERS_OFFLINE'] = '1' + os.environ['HF_HUB_OFFLINE'] = '1' + + # 尝试加载模型,如果网络失败则使用本地缓存 + try: + # 加载模型 + if self.num_gpus > 1: + # 多GPU加载 + max_memory = {i: "18GiB" for i in self.device_ids} + + self.model = AutoModel.from_pretrained( + self.model_name, + trust_remote_code=True, + torch_dtype=torch.float16, + device_map="auto", + max_memory=max_memory, + low_cpu_mem_usage=True, + local_files_only=False # 允许从网络下载 + ) + else: + # 单GPU加载 + self.model = AutoModel.from_pretrained( + self.model_name, + trust_remote_code=True, + torch_dtype=torch.float16, + device_map=self.primary_device, + local_files_only=False # 允许从网络下载 + ) - self.model = AutoModel.from_pretrained( - self.model_name, - trust_remote_code=True, - torch_dtype=torch.float16, - device_map="auto", - max_memory=max_memory, - low_cpu_mem_usage=True - ) - else: - # 单GPU加载 - self.model = AutoModel.from_pretrained( - self.model_name, - trust_remote_code=True, - torch_dtype=torch.float16, - device_map=self.primary_device - ) + logger.info("模型从网络加载成功") + + except Exception as network_error: + logger.warning(f"网络加载失败,尝试本地缓存: {network_error}") + + # 尝试从本地缓存加载 + try: + if self.num_gpus > 1: + max_memory = {i: "18GiB" for i in self.device_ids} + + self.model = AutoModel.from_pretrained( + self.model_name, + trust_remote_code=True, + torch_dtype=torch.float16, + device_map="auto", + max_memory=max_memory, + low_cpu_mem_usage=True, + local_files_only=True # 仅使用本地文件 + ) + else: + self.model = AutoModel.from_pretrained( + self.model_name, + trust_remote_code=True, + torch_dtype=torch.float16, + device_map=self.primary_device, + local_files_only=True # 仅使用本地文件 + ) + + logger.info("模型从本地缓存加载成功") + + except Exception as local_error: + logger.error(f"本地缓存加载也失败: {local_error}") + raise local_error # 加载分词器和处理器 - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_name, - trust_remote_code=True - ) + try: + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + trust_remote_code=True, + local_files_only=False + ) + except Exception as e: + logger.warning(f"Tokenizer网络加载失败,尝试本地: {e}") + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + trust_remote_code=True, + local_files_only=True + ) try: self.processor = AutoProcessor.from_pretrained( self.model_name, - trust_remote_code=True + trust_remote_code=True, + local_files_only=False ) except Exception as e: logger.warning(f"Processor加载失败,使用tokenizer: {e}") - self.processor = self.tokenizer + try: + self.processor = AutoProcessor.from_pretrained( + self.model_name, + trust_remote_code=True, + local_files_only=True + ) + except Exception as e2: + logger.warning(f"Processor本地加载也失败,使用tokenizer: {e2}") + self.processor = self.tokenizer logger.info("模型加载完成") return True @@ -274,6 +341,10 @@ class MultimodalRetrievalVDB: Returns: 存储的ID列表 """ + if self.vdb is None: + logger.warning("VDB不可用,文本数据将不会持久化存储") + return [] + logger.info(f"正在存储 {len(texts)} 条文本数据") # 分批处理 @@ -312,6 +383,10 @@ class MultimodalRetrievalVDB: Returns: 存储的ID列表 """ + if self.vdb is None: + logger.warning("VDB不可用,图像数据将不会持久化存储") + return [] + logger.info(f"正在存储 {len(image_paths)} 张图像数据") # 图像处理使用更小的批次 @@ -341,6 +416,10 @@ class MultimodalRetrievalVDB: def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: """文搜文:使用文本查询搜索相似文本""" + if self.vdb is None: + logger.warning("VDB不可用,无法执行搜索") + return [] + logger.info(f"执行文搜文查询: {query}") # 编码查询文本 @@ -358,6 +437,10 @@ class MultimodalRetrievalVDB: def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: """文搜图:使用文本查询搜索相似图像""" + if self.vdb is None: + logger.warning("VDB不可用,无法执行搜索") + return [] + logger.info(f"执行文搜图查询: {query}") # 编码查询文本 @@ -375,6 +458,10 @@ class MultimodalRetrievalVDB: def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: """图搜文:使用图像查询搜索相似文本""" + if self.vdb is None: + logger.warning("VDB不可用,无法执行搜索") + return [] + logger.info(f"执行图搜文查询") # 编码查询图像 @@ -392,6 +479,10 @@ class MultimodalRetrievalVDB: def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: """图搜图:使用图像查询搜索相似图像""" + if self.vdb is None: + logger.warning("VDB不可用,无法执行搜索") + return [] + logger.info(f"执行图搜图查询") # 编码查询图像 @@ -426,10 +517,15 @@ class MultimodalRetrievalVDB: def get_statistics(self) -> Dict[str, Any]: """获取系统统计信息""" + if self.vdb is None: + return {"error": "VDB不可用"} return self.vdb.get_statistics() def clear_all_data(self): """清空所有数据""" + if self.vdb is None: + logger.warning("VDB不可用,无法清空数据") + return self.vdb.clear_all_data() def close(self): diff --git a/nohup.out b/nohup.out new file mode 100644 index 0000000..b11bb79 --- /dev/null +++ b/nohup.out @@ -0,0 +1,49 @@ +INFO:baidu_bos_manager:✅ BOS连接测试成功 +INFO:baidu_bos_manager:✅ BOS客户端初始化成功: dmtyz-demo +INFO:mongodb_manager:✅ MongoDB连接成功: mmeb +INFO:mongodb_manager:✅ MongoDB索引创建完成 +INFO:__main__:初始化多模态检索系统... +INFO:multimodal_retrieval_local:使用GPU: [0, 1] +INFO:multimodal_retrieval_local:加载本地模型和处理器: /root/models/Ops-MM-embedding-v1-7B +The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release. +You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0. +INFO:multimodal_retrieval_local:Processor类型: +INFO:multimodal_retrieval_local:Processor方法: ['__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_check_special_mm_tokens', '_create_repo', '_get_arguments_from_pretrained', '_get_files_timestamps', '_get_num_multimodal_tokens', '_merge_kwargs', '_upload_modified_files', 'apply_chat_template', 'attributes', 'audio_tokenizer', 'batch_decode', 'chat_template', 'check_argument_for_proper_class', 'decode', 'feature_extractor_class', 'from_args_and_dict', 'from_pretrained', 'get_possibly_dynamic_module', 'get_processor_dict', 'image_processor', 'image_processor_class', 'image_token', 'image_token_id', 'model_input_names', 'optional_attributes', 'optional_call_args', 'post_process_image_text_to_text', 'push_to_hub', 'register_for_auto_class', 'save_pretrained', 'to_dict', 'to_json_file', 'to_json_string', 'tokenizer', 'tokenizer_class', 'validate_init_kwargs', 'video_processor', 'video_processor_class', 'video_token', 'video_token_id'] +INFO:multimodal_retrieval_local:Image processor类型: +INFO:multimodal_retrieval_local:Image processor方法: ['__backends', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slotnames__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_create_repo', '_further_process_kwargs', '_fuse_mean_std_and_rescale_factor', '_get_files_timestamps', '_prepare_image_like_inputs', '_prepare_images_structure', '_preprocess', '_preprocess_image_like_inputs', '_process_image', '_processor_class', '_set_processor_class', '_upload_modified_files', '_valid_kwargs_names', '_validate_preprocess_kwargs', 'center_crop', 'compile_friendly_resize', 'convert_to_rgb', 'crop_size', 'data_format', 'default_to_square', 'device', 'disable_grouping', 'do_center_crop', 'do_convert_rgb', 'do_normalize', 'do_rescale', 'do_resize', 'fetch_images', 'filter_out_unused_kwargs', 'from_dict', 'from_json_file', 'from_pretrained', 'get_image_processor_dict', 'get_number_of_image_patches', 'image_mean', 'image_processor_type', 'image_std', 'input_data_format', 'max_pixels', 'merge_size', 'min_pixels', 'model_input_names', 'normalize', 'patch_size', 'preprocess', 'push_to_hub', 'register_for_auto_class', 'resample', 'rescale', 'rescale_and_normalize', 'rescale_factor', 'resize', 'return_tensors', 'save_pretrained', 'size', 'temporal_patch_size', 'to_dict', 'to_json_file', 'to_json_string', 'unused_kwargs', 'valid_kwargs'] + Loading checkpoint shards: 0%| | 0/4 [00:00 +INFO:multimodal_retrieval_local:encode_image: 图像列表,长度: 1 +INFO:multimodal_retrieval_local:encode_image: 处理图像输入 +INFO:multimodal_retrieval_local:encode_image: 图像 0 格式: JPEG, 模式: RGB, 大小: (939, 940) +ERROR:multimodal_retrieval_local:encode_image: 处理图像时出错: argument of type 'NoneType' is not iterable +ERROR:multimodal_retrieval_local:add_images: 图像编码失败,返回空数组 +INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index +INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:50] "POST /api/add_image HTTP/1.1" 200 - +INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index +INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:50] "POST /api/save_index HTTP/1.1" 200 - +INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:51] "GET /api/system_info HTTP/1.1" 200 - diff --git a/optimized_file_handler.py b/optimized_file_handler.py index fed7384..1679669 100644 --- a/optimized_file_handler.py +++ b/optimized_file_handler.py @@ -30,19 +30,30 @@ class OptimizedFileHandler: # 支持的图像格式 SUPPORTED_IMAGE_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'} - def __init__(self): + def __init__(self, local_storage_dir=None): self.bos_manager = get_bos_manager() self.mongodb_manager = get_mongodb_manager() self.temp_files = set() # 跟踪临时文件 + self.local_storage_dir = local_storage_dir or tempfile.gettempdir() + + # 确保本地存储目录存在 + if self.local_storage_dir: + os.makedirs(self.local_storage_dir, exist_ok=True) @contextmanager - def temp_file_context(self, suffix: str = None, delete_on_exit: bool = True): + def temp_file_context(self, content: bytes = None, suffix: str = None, delete_on_exit: bool = True): """临时文件上下文管理器,确保自动清理""" - temp_fd, temp_path = tempfile.mkstemp(suffix=suffix) + temp_fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=self.local_storage_dir) self.temp_files.add(temp_path) - try: + # 如果提供了内容,写入文件 + if content is not None: + with os.fdopen(temp_fd, 'wb') as f: + f.write(content) + else: os.close(temp_fd) # 关闭文件描述符 + + try: yield temp_path finally: if delete_on_exit and os.path.exists(temp_path): @@ -96,17 +107,13 @@ class OptimizedFileHandler: logger.error(f"❌ 图像验证失败: {filename}, {e}") return None - # 生成唯一ID和BOS键 + # 生成唯一ID file_id = str(uuid.uuid4()) - bos_key = f"images/memory_{file_id}_{filename}" - # 直接上传到BOS(从内存) - bos_result = self._upload_to_bos_from_memory( - file_content, bos_key, filename - ) - - if not bos_result: - return None + # 保存到本地存储 + local_path = os.path.join(self.local_storage_dir, f"{file_id}_{filename}") + with open(local_path, 'wb') as f: + f.write(file_content) # 存储元数据到MongoDB metadata = { @@ -115,18 +122,25 @@ class OptimizedFileHandler: "file_type": "image", "file_size": len(file_content), "processing_method": "memory", - "bos_key": bos_key, - "bos_url": bos_result["url"] + "local_path": local_path } - self.mongodb_manager.store_file_metadata(metadata=metadata) + # 如果有BOS管理器,也上传到BOS + if hasattr(self, 'bos_manager') and self.bos_manager: + bos_key = f"images/memory_{file_id}_{filename}" + bos_result = self._upload_to_bos_from_memory(file_content, bos_key, filename) + if bos_result: + metadata["bos_key"] = bos_key + metadata["bos_url"] = bos_result["url"] + + if hasattr(self, 'mongodb_manager') and self.mongodb_manager: + self.mongodb_manager.store_file_metadata(metadata=metadata) logger.info(f"✅ 内存处理图像成功: {filename} ({len(file_content)} bytes)") return { "file_id": file_id, "filename": filename, - "bos_key": bos_key, - "bos_result": bos_result, + "local_path": local_path, "processing_method": "memory" } @@ -140,6 +154,12 @@ class OptimizedFileHandler: # 获取文件扩展名 ext = os.path.splitext(filename)[1].lower() + # 生成唯一ID + file_id = str(uuid.uuid4()) + + # 创建永久文件路径 + permanent_path = os.path.join(self.local_storage_dir, f"{file_id}_{filename}") + with self.temp_file_context(suffix=ext) as temp_path: # 保存到临时文件 file_obj.seek(0) @@ -154,35 +174,41 @@ class OptimizedFileHandler: logger.error(f"❌ 图像验证失败: {filename}, {e}") return None - # 生成唯一ID和BOS键 - file_id = str(uuid.uuid4()) - bos_key = f"images/temp_{file_id}_{filename}" + # 复制到永久存储位置 + with open(temp_path, 'rb') as src, open(permanent_path, 'wb') as dst: + dst.write(src.read()) - # 上传到BOS - bos_result = self.bos_manager.upload_file(temp_path, bos_key) + # 获取文件信息 + file_stat = os.stat(permanent_path) - # 存储元数据到MongoDB - file_stat = os.stat(temp_path) + # 存储元数据 metadata = { "_id": file_id, "filename": filename, "file_type": "image", "file_size": file_stat.st_size, "processing_method": "temp_file", - "bos_key": bos_key, - "bos_url": bos_result["url"] + "local_path": permanent_path } - self.mongodb_manager.store_file_metadata(metadata=metadata) + # 如果有BOS管理器,也上传到BOS + if hasattr(self, 'bos_manager') and self.bos_manager: + bos_key = f"images/temp_{file_id}_{filename}" + bos_result = self.bos_manager.upload_file(temp_path, bos_key) + if bos_result: + metadata["bos_key"] = bos_key + metadata["bos_url"] = bos_result["url"] + + # 存储元数据到MongoDB + if hasattr(self, 'mongodb_manager') and self.mongodb_manager: + self.mongodb_manager.store_file_metadata(metadata=metadata) logger.info(f"✅ 临时文件处理图像成功: {filename} ({file_stat.st_size} bytes)") return { "file_id": file_id, "filename": filename, - "bos_key": bos_key, - "bos_result": bos_result, - "processing_method": "temp_file", - "temp_path": temp_path # 返回临时路径供模型处理 + "local_path": permanent_path, + "processing_method": "temp_file" } except Exception as e: @@ -290,8 +316,11 @@ class OptimizedFileHandler: try: ext = os.path.splitext(filename)[1].lower() + # 生成唯一ID + file_id = str(uuid.uuid4()) + # 创建临时文件(不自动删除,供模型使用) - temp_fd, temp_path = tempfile.mkstemp(suffix=ext) + temp_fd, temp_path = tempfile.mkstemp(suffix=ext, dir=self.local_storage_dir) self.temp_files.add(temp_path) try: diff --git a/templates/local_index.html b/templates/local_index.html new file mode 100644 index 0000000..e279750 --- /dev/null +++ b/templates/local_index.html @@ -0,0 +1,995 @@ + + + + + + 本地多模态检索系统 - FAISS + + + + + + +
+
+ 未初始化 +
+
+ +
+
+ +
+

本地多模态检索系统

+

基于本地模型和FAISS向量数据库,支持文搜图、文搜文、图搜图、图搜文四种检索模式

+
+ +
+ +
+ +
+ + +
+
+
+ +
文搜文
+

文本查找相似文本

+
+
+
+
+ +
文搜图
+

文本查找相关图片

+
+
+
+
+ +
图搜文
+

图片查找相关文本

+
+
+
+
+ +
图搜图
+

图片查找相似图片

+
+
+
+ + +
+
+
+
+
数据管理
+ 上传和管理检索数据库 +
+
+
+ +
+
+
批量上传图片
+
+ +

拖拽多张图片到此处或点击选择

+ + +
+ +
+
+ + +
+
+
批量上传文本
+
+ +
+
+ + + +
+
+
+
+ + +
+
+
+ + + +
+
+
+
+ + 图片: 0 张 | + 文本: 0 条 + +
+
+
+
+
+
+
+ + + + + +
+
+ Loading... +
+

正在搜索中...

+
+ + +
+
+
+
+ + + + + diff --git a/test_faiss_local.log b/test_faiss_local.log new file mode 100644 index 0000000..e69de29 diff --git a/test_faiss_simple.py b/test_faiss_simple.py new file mode 100644 index 0000000..d48a949 --- /dev/null +++ b/test_faiss_simple.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +FAISS多模态检索系统简单测试 +""" + +import sys +import os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from multimodal_retrieval_faiss import MultimodalRetrievalFAISS +from PIL import Image +import numpy as np + +def test_text_retrieval(): + print("=== 测试文本检索 ===") + + # 初始化检索系统 + print("初始化检索系统...") + retrieval = MultimodalRetrievalFAISS( + model_name="OpenSearch-AI/Ops-MM-embedding-v1-7B", + use_all_gpus=True, + index_path="faiss_index_test" + ) + + # 测试文本 + texts = [ + "一只可爱的橘色猫咪在沙发上睡觉", + "城市夜景中的高楼大厦和车流", + "阳光明媚的海滩上,人们在冲浪和晒太阳", + "美味的意大利面配红酒和沙拉", + "雪山上滑雪的运动员" + ] + + # 添加文本 + print("\n添加文本到检索系统...") + text_ids = retrieval.add_texts(texts) + print(f"添加了{len(text_ids)}条文本") + print(f"当前向量数量: {retrieval.get_vector_count()}") + + # 测试文本搜索 + print("\n测试文本搜索...") + queries = ["一只猫在睡觉", "都市风光", "海边的景色"] + + for query in queries: + print(f"\n查询: {query}") + results = retrieval.search_by_text(query, k=2) + for i, result in enumerate(results): + print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})") + + # 保存索引 + print("\n保存索引...") + retrieval.save_index() + + print("\n测试完成!") + +if __name__ == "__main__": + test_text_retrieval() diff --git a/test_faiss_with_proxy.py b/test_faiss_with_proxy.py new file mode 100644 index 0000000..d7af08b --- /dev/null +++ b/test_faiss_with_proxy.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +FAISS多模态检索系统简单测试 - 带代理设置 +""" + +import sys +import os +import logging + +# 设置代理 +os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890' # 根据实际情况修改 +os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890' # 根据实际情况修改 + +# 设置日志 +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# 设置离线模式,避免下载模型 +os.environ['TRANSFORMERS_OFFLINE'] = '1' + +# 添加当前目录到路径 +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +# 使用简单的向量模型替代大型多模态模型 +from sentence_transformers import SentenceTransformer +import faiss +import numpy as np + +class SimpleFaissRetrieval: + """简化版FAISS检索系统,使用sentence-transformers""" + + def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2", index_path="simple_faiss_index"): + """ + 初始化简化版检索系统 + + Args: + model_name: 模型名称,使用轻量级模型 + index_path: 索引文件路径 + """ + self.model_name = model_name + self.index_path = index_path + + logger.info(f"加载模型: {model_name}") + try: + # 尝试加载模型 + self.model = SentenceTransformer(model_name) + self.dimension = self.model.get_sentence_embedding_dimension() + logger.info(f"模型加载成功,向量维度: {self.dimension}") + except Exception as e: + logger.error(f"模型加载失败: {str(e)}") + logger.info("使用随机向量模拟...") + self.model = None + self.dimension = 384 # 默认维度 + + # 初始化索引 + self.index = faiss.IndexFlatL2(self.dimension) + self.metadata = {} + + logger.info("检索系统初始化完成") + + def encode_text(self, text): + """编码文本为向量""" + if self.model is None: + # 如果模型加载失败,使用随机向量 + if isinstance(text, list): + vectors = np.random.rand(len(text), self.dimension).astype('float32') + return vectors + else: + return np.random.rand(self.dimension).astype('float32') + else: + # 使用模型编码 + return self.model.encode(text, convert_to_numpy=True) + + def add_texts(self, texts, metadatas=None): + """添加文本到索引""" + if not texts: + return [] + + if metadatas is None: + metadatas = [{} for _ in range(len(texts))] + + # 编码文本 + vectors = self.encode_text(texts) + + # 添加到索引 + start_id = len(self.metadata) + ids = list(range(start_id, start_id + len(texts))) + + self.index.add(np.array(vectors).astype('float32')) + + # 保存元数据 + for i, id in enumerate(ids): + self.metadata[str(id)] = { + "text": texts[i], + "type": "text", + **metadatas[i] + } + + logger.info(f"添加了{len(ids)}条文本,当前索引大小: {self.index.ntotal}") + return [str(id) for id in ids] + + def search(self, query, k=5): + """搜索相似文本""" + # 编码查询 + query_vector = self.encode_text(query) + if len(query_vector.shape) == 1: + query_vector = query_vector.reshape(1, -1) + + # 搜索 + distances, indices = self.index.search(query_vector.astype('float32'), k) + + # 处理结果 + results = [] + for i in range(len(indices[0])): + idx = indices[0][i] + if idx < 0: + continue + + vector_id = str(idx) + if vector_id in self.metadata: + result = self.metadata[vector_id].copy() + result['score'] = float(1.0 / (1.0 + distances[0][i])) + results.append(result) + + return results + +def test_simple_retrieval(): + """测试简化版检索系统""" + print("=== 测试简化版FAISS检索系统 ===") + + # 初始化检索系统 + print("初始化检索系统...") + retrieval = SimpleFaissRetrieval() + + # 测试文本 + texts = [ + "一只可爱的橘色猫咪在沙发上睡觉", + "城市夜景中的高楼大厦和车流", + "阳光明媚的海滩上,人们在冲浪和晒太阳", + "美味的意大利面配红酒和沙拉", + "雪山上滑雪的运动员" + ] + + # 添加文本 + print("\n添加文本到检索系统...") + text_ids = retrieval.add_texts(texts) + print(f"添加了{len(text_ids)}条文本") + + # 测试文本搜索 + print("\n测试文本搜索...") + queries = ["一只猫在睡觉", "都市风光", "海边的景色"] + + for query in queries: + print(f"\n查询: {query}") + results = retrieval.search(query, k=2) + for i, result in enumerate(results): + print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})") + + print("\n测试完成!") + +if __name__ == "__main__": + test_simple_retrieval() diff --git a/test_fixes.py b/test_fixes.py new file mode 100644 index 0000000..f8e29e9 --- /dev/null +++ b/test_fixes.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +测试修复后的系统功能 +""" + +import requests +import time +import json + +def test_system(): + """测试系统功能""" + base_url = "http://localhost:5000" + + print("🧪 开始测试修复后的系统...") + print("=" * 50) + + # 测试1: 检查系统状态 + print("1. 测试系统状态...") + try: + response = requests.get(f"{base_url}/api/status", timeout=10) + if response.status_code == 200: + status = response.json() + print(f" ✅ 系统状态: {status}") + else: + print(f" ❌ 状态检查失败: {response.status_code}") + except Exception as e: + print(f" ❌ 状态检查异常: {e}") + + # 测试2: 检查数据统计 + print("\n2. 测试数据统计...") + try: + response = requests.get(f"{base_url}/api/data/stats", timeout=10) + if response.status_code == 200: + stats = response.json() + print(f" ✅ 数据统计: {stats}") + else: + print(f" ❌ 统计检查失败: {response.status_code}") + except Exception as e: + print(f" ❌ 统计检查异常: {e}") + + # 测试3: 检查数据列表 + print("\n3. 测试数据列表...") + try: + response = requests.get(f"{base_url}/api/data/list", timeout=10) + if response.status_code == 200: + data_list = response.json() + print(f" ✅ 数据列表: {data_list}") + else: + print(f" ❌ 列表检查失败: {response.status_code}") + except Exception as e: + print(f" ❌ 列表检查异常: {e}") + + # 测试4: 测试文本搜索(如果系统已初始化) + print("\n4. 测试文本搜索...") + try: + search_data = { + "query": "测试查询", + "top_k": 3 + } + response = requests.post(f"{base_url}/api/search/text_to_text", + json=search_data, timeout=10) + if response.status_code == 200: + result = response.json() + print(f" ✅ 文本搜索: {result}") + else: + print(f" ❌ 文本搜索失败: {response.status_code}") + except Exception as e: + print(f" ❌ 文本搜索异常: {e}") + + print("\n" + "=" * 50) + print("🎉 测试完成!") + +if __name__ == "__main__": + # 等待系统启动 + print("⏳ 等待系统启动...") + time.sleep(5) + + test_system() diff --git a/test_local_model.py b/test_local_model.py new file mode 100644 index 0000000..f3f370a --- /dev/null +++ b/test_local_model.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +使用本地模型的FAISS多模态检索系统测试 +""" + +import os +import sys +import logging +from pathlib import Path +import numpy as np +import faiss +from typing import List, Dict, Any, Optional, Union +import json + +# 设置日志 +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# 设置离线模式 +os.environ['TRANSFORMERS_OFFLINE'] = '1' + +def test_local_model(): + """测试本地模型加载""" + from transformers import AutoModel, AutoTokenizer, AutoProcessor + import torch + from PIL import Image + + # 这里替换为您实际下载的模型路径 + local_model_path = "/root/models/Ops-MM-embedding-v1-7B" + + if not os.path.exists(local_model_path): + logger.error(f"模型路径不存在: {local_model_path}") + logger.info("请先下载模型到指定路径") + return + + logger.info(f"加载本地模型: {local_model_path}") + + try: + # 加载tokenizer + logger.info("加载tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(local_model_path) + + # 加载processor + logger.info("加载processor...") + processor = AutoProcessor.from_pretrained(local_model_path) + + # 加载模型 + logger.info("加载模型...") + model = AutoModel.from_pretrained( + local_model_path, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, + device_map="auto" if torch.cuda.device_count() > 0 else None + ) + + logger.info("模型加载成功!") + + # 测试文本编码 + logger.info("测试文本编码...") + text = "这是一个测试文本" + inputs = tokenizer(text, return_tensors="pt") + if torch.cuda.is_available(): + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + outputs = model(**inputs) + text_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy() + + logger.info(f"文本编码维度: {text_embedding.shape}") + + # 如果有图像处理功能,测试图像编码 + try: + logger.info("测试图像编码...") + # 创建一个简单的测试图像 + image = Image.new('RGB', (224, 224), color='red') + image_inputs = processor(images=image, return_tensors="pt") + + if torch.cuda.is_available(): + image_inputs = {k: v.to("cuda") for k, v in image_inputs.items()} + + with torch.no_grad(): + image_outputs = model.vision_model(**image_inputs) + image_embedding = image_outputs.pooler_output.cpu().numpy() + + logger.info(f"图像编码维度: {image_embedding.shape}") + + except Exception as e: + logger.error(f"图像编码测试失败: {str(e)}") + + logger.info("本地模型测试完成!") + + except Exception as e: + logger.error(f"模型加载失败: {str(e)}") + logger.error("请确保模型文件已正确下载") + +if __name__ == "__main__": + test_local_model() diff --git a/test_local_retrieval.py b/test_local_retrieval.py new file mode 100644 index 0000000..19ca50b --- /dev/null +++ b/test_local_retrieval.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +测试本地模型和FAISS向量数据库的多模态检索系统 +""" + +import os +import sys +import logging +from pathlib import Path +import time +from PIL import Image +import numpy as np +from multimodal_retrieval_local import MultimodalRetrievalLocal + +# 设置日志 +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# 设置离线模式 +os.environ['TRANSFORMERS_OFFLINE'] = '1' + +def test_text_retrieval(): + """测试文本检索功能""" + print("\n=== 测试文本检索 ===") + + # 初始化检索系统 + print("初始化检索系统...") + retrieval = MultimodalRetrievalLocal( + model_path="/root/models/Ops-MM-embedding-v1-7B", + use_all_gpus=True, + index_path="local_faiss_text_test" + ) + + # 测试文本 + texts = [ + "一只可爱的橘色猫咪在沙发上睡觉", + "城市夜景中的高楼大厦和车流", + "阳光明媚的海滩上,人们在冲浪和晒太阳", + "美味的意大利面配红酒和沙拉", + "雪山上滑雪的运动员" + ] + + # 添加文本 + print("\n添加文本到检索系统...") + text_ids = retrieval.add_texts(texts) + print(f"添加了{len(text_ids)}条文本") + + # 获取统计信息 + stats = retrieval.get_stats() + print(f"检索系统统计信息: {stats}") + + # 测试文本搜索 + print("\n测试文本搜索...") + queries = ["一只猫在睡觉", "都市风光", "海边的景色"] + + for query in queries: + print(f"\n查询: {query}") + results = retrieval.search_by_text(query, k=2) + for i, result in enumerate(results): + print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})") + + # 保存索引 + print("\n保存索引...") + retrieval.save_index() + + print("\n文本检索测试完成!") + return retrieval + +def test_image_retrieval(): + """测试图像检索功能""" + print("\n=== 测试图像检索 ===") + + # 初始化检索系统 + print("初始化检索系统...") + retrieval = MultimodalRetrievalLocal( + model_path="/root/models/Ops-MM-embedding-v1-7B", + use_all_gpus=True, + index_path="local_faiss_image_test" + ) + + # 创建测试图像 + print("\n创建测试图像...") + images = [] + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)] + image_paths = [] + + for i, color in enumerate(colors): + img = Image.new('RGB', (224, 224), color=color) + images.append(img) + + # 保存图像 + img_path = f"/tmp/test_image_{i}.png" + img.save(img_path) + image_paths.append(img_path) + print(f"创建图像: {img_path}") + + # 添加图像 + print("\n添加图像到检索系统...") + metadatas = [{"description": f"测试图像 {i+1}"} for i in range(len(images))] + image_ids = retrieval.add_images(images, metadatas, image_paths) + print(f"添加了{len(image_ids)}张图像") + + # 获取统计信息 + stats = retrieval.get_stats() + print(f"检索系统统计信息: {stats}") + + # 测试图像搜索 + print("\n测试图像搜索...") + query_image = Image.new('RGB', (224, 224), color=(255, 0, 0)) # 红色图像 + + print("\n使用图像查询图像:") + results = retrieval.search_by_image(query_image, k=2, filter_type="image") + for i, result in enumerate(results): + print(f" 结果 {i+1}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})") + + # 保存索引 + print("\n保存索引...") + retrieval.save_index() + + print("\n图像检索测试完成!") + return retrieval + +def test_cross_modal_retrieval(): + """测试跨模态检索功能""" + print("\n=== 测试跨模态检索 ===") + + # 初始化检索系统 + print("初始化检索系统...") + retrieval = MultimodalRetrievalLocal( + model_path="/root/models/Ops-MM-embedding-v1-7B", + use_all_gpus=True, + index_path="local_faiss_cross_modal_test" + ) + + # 添加文本 + texts = [ + "一只红色的苹果", + "绿色的草地", + "蓝色的大海", + "黄色的向日葵", + "青色的天空" + ] + print("\n添加文本到检索系统...") + text_ids = retrieval.add_texts(texts) + print(f"添加了{len(text_ids)}条文本") + + # 添加图像 + print("\n添加图像到检索系统...") + images = [] + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)] + descriptions = ["红色图像", "绿色图像", "蓝色图像", "黄色图像", "青色图像"] + + for i, color in enumerate(colors): + img = Image.new('RGB', (224, 224), color=color) + images.append(img) + + metadatas = [{"description": desc} for desc in descriptions] + image_ids = retrieval.add_images(images, metadatas) + print(f"添加了{len(image_ids)}张图像") + + # 获取统计信息 + stats = retrieval.get_stats() + print(f"检索系统统计信息: {stats}") + + # 测试文搜图 + print("\n测试文搜图...") + query_text = "红色" + print(f"查询文本: {query_text}") + results = retrieval.search_by_text(query_text, k=2, filter_type="image") + for i, result in enumerate(results): + print(f" 结果 {i+1}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})") + + # 测试图搜文 + print("\n测试图搜文...") + query_image = Image.new('RGB', (224, 224), color=(0, 0, 255)) # 蓝色图像 + print("查询图像: 蓝色图像") + results = retrieval.search_by_image(query_image, k=2, filter_type="text") + for i, result in enumerate(results): + print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})") + + # 保存索引 + print("\n保存索引...") + retrieval.save_index() + + print("\n跨模态检索测试完成!") + return retrieval + +def main(): + """主函数""" + print("=== 本地多模态检索系统测试 ===") + + # 检查模型路径 + model_path = "/root/models/Ops-MM-embedding-v1-7B" + if not os.path.exists(model_path): + print(f"错误: 模型路径不存在: {model_path}") + print("请先下载模型到指定路径") + return + + # 检查模型文件 + config_file = os.path.join(model_path, "config.json") + if not os.path.exists(config_file): + print(f"错误: 模型配置文件不存在: {config_file}") + print("请确保模型文件已正确下载") + return + + print(f"模型路径验证成功: {model_path}") + + # 运行测试 + try: + # 测试文本检索 + test_text_retrieval() + + # 测试图像检索 + test_image_retrieval() + + # 测试跨模态检索 + test_cross_modal_retrieval() + + print("\n所有测试完成!") + + except Exception as e: + print(f"测试过程中发生错误: {str(e)}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() diff --git a/web_app.log b/web_app.log new file mode 100644 index 0000000..0cc7af7 --- /dev/null +++ b/web_app.log @@ -0,0 +1,63 @@ +nohup: ignoring input +INFO:__main__:🚀 启动时自动初始化VDB多模态检索系统... +INFO:multimodal_retrieval_vdb:检测到 2 个GPU +INFO:multimodal_retrieval_vdb:使用GPU: [0, 1], 主设备: cuda:0 +INFO:multimodal_retrieval_vdb:GPU内存已清理 +INFO:multimodal_retrieval_vdb:正在加载模型到GPU: [0, 1] +INFO:multimodal_retrieval_vdb:GPU内存已清理 +🚀 启动VDB多模态检索Web应用 +============================================================ +访问地址: http://localhost:5000 +新功能: + 🗄️ 百度VDB - 向量数据库存储 + 📊 实时统计 - VDB数据统计信息 + 🔄 数据同步 - 本地文件到VDB存储 +支持功能: + 📝 文搜文 - 文本查找相似文本 + 🖼️ 文搜图 - 文本查找相关图片 + 📝 图搜文 - 图片查找相关文本 + 🖼️ 图搜图 - 图片查找相似图片 + 📤 批量上传 - 图片和文本数据管理 +GPU配置: + 🖥️ 检测到 2 个GPU + GPU 0: NVIDIA GeForce RTX 4090 (23.6GB) + GPU 1: NVIDIA GeForce RTX 4090 (23.6GB) +VDB配置: + 🌐 服务器: http://180.76.96.191:5287 + 👤 用户: root + 🗄️ 数据库: multimodal_retrieval +============================================================ + Loading checkpoint shards: 0%| | 0/4 [00:00: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 103ac836-6599-4fe2-a569-aed9c945525c)') +The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release. +WARNING:multimodal_retrieval_vdb:Processor加载失败,使用tokenizer: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main/additional_chat_templates?recursive=False&expand=False (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 96f18121-7beb-4e1a-87cd-c50edf682933)') +You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0. +INFO:multimodal_retrieval_vdb:模型加载完成 +INFO:baidu_vdb_backend:✅ 成功连接到百度VDB: http://180.76.96.191:5287 +INFO:baidu_vdb_backend:使用现有数据库: multimodal_retrieval +INFO:baidu_vdb_backend:创建文本向量表: text_vectors +ERROR:baidu_vdb_backend:❌ 创建文本表失败: Database.create_table() missing 1 required positional argument: 'partition' +ERROR:baidu_vdb_backend:❌ 表操作失败: Database.create_table() missing 1 required positional argument: 'partition' +ERROR:multimodal_retrieval_vdb:❌ VDB后端初始化失败: Database.create_table() missing 1 required positional argument: 'partition' +WARNING:multimodal_retrieval_vdb:⚠️ 系统将在无VDB模式下运行,数据将不会持久化 +INFO:multimodal_retrieval_vdb:多模态检索系统初始化完成 +ERROR:__main__:❌ VDB系统自动初始化失败: VDB连接失败 +ERROR:__main__:Traceback (most recent call last): + File "/root/mmeb/web_app_vdb.py", line 667, in auto_initialize + raise Exception("VDB连接失败") +Exception: VDB连接失败 + + * Serving Flask app 'web_app_vdb' + * Debug mode: off +Address already in use +Port 5000 is in use by another program. Either identify and stop that program, or start the server with a different port. +败 +ERROR:__main__:Traceback (most recent call last): + File "/root/mmeb/web_app_vdb.py", line 664, in auto_initialize + raise Exception("模型加载失败") +Exception: 模型加载失败 + + * Serving Flask app 'web_app_vdb' + * Debug mode: off +Address already in use +Port 5000 is in use by another program. Either identify and stop that program, or start the server with a different port. diff --git a/web_app_local.py b/web_app_local.py new file mode 100644 index 0000000..d6bd994 --- /dev/null +++ b/web_app_local.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +本地多模态检索系统Web应用 +集成本地模型和FAISS向量数据库 +支持文搜文、文搜图、图搜文、图搜图四种检索模式 +""" + +import os +import sys +import logging +import time +import json +import base64 +from io import BytesIO +from pathlib import Path +import numpy as np +from PIL import Image +from flask import Flask, request, jsonify, render_template, send_from_directory +from werkzeug.utils import secure_filename +import torch + +# 设置离线模式 +os.environ['TRANSFORMERS_OFFLINE'] = '1' + +# 导入本地模块 +from multimodal_retrieval_local import MultimodalRetrievalLocal +from optimized_file_handler import OptimizedFileHandler + +# 设置日志 +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# 创建Flask应用 +app = Flask(__name__) + +# 配置 +app.config['UPLOAD_FOLDER'] = '/tmp/mmeb_uploads' +app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB +app.config['MODEL_PATH'] = '/root/models/Ops-MM-embedding-v1-7B' +app.config['INDEX_PATH'] = '/root/mmeb/local_faiss_index' +app.config['ALLOWED_EXTENSIONS'] = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'} + +# 确保上传目录存在 +os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) + +# 创建临时文件夹 +if not os.path.exists(app.config['UPLOAD_FOLDER']): + os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) + +# 创建文件处理器 +from optimized_file_handler import OptimizedFileHandler +file_handler = OptimizedFileHandler(local_storage_dir=app.config['UPLOAD_FOLDER']) + +# 全局变量 +retrieval_system = None + +def allowed_file(filename): + """检查文件扩展名是否允许""" + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS'] + +def init_retrieval_system(): + """初始化检索系统""" + global retrieval_system + + if retrieval_system is not None: + return retrieval_system + + logger.info("初始化多模态检索系统...") + + # 检查模型路径 + model_path = app.config['MODEL_PATH'] + if not os.path.exists(model_path): + logger.error(f"模型路径不存在: {model_path}") + raise FileNotFoundError(f"模型路径不存在: {model_path}") + + # 初始化检索系统 + retrieval_system = MultimodalRetrievalLocal( + model_path=model_path, + use_all_gpus=True, + index_path=app.config['INDEX_PATH'] + ) + + logger.info("多模态检索系统初始化完成") + return retrieval_system + +def get_image_base64(image_path): + """将图像转换为base64编码""" + with open(image_path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()).decode('utf-8') + return f"data:image/jpeg;base64,{encoded_string}" + +@app.route('/') +def index(): + """首页""" + return render_template('local_index.html') + +@app.route('/api/stats', methods=['GET']) +def get_stats(): + """获取系统统计信息""" + try: + retrieval = init_retrieval_system() + stats = retrieval.get_stats() + return jsonify({"success": True, "stats": stats}) + except Exception as e: + logger.error(f"获取统计信息失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + +@app.route('/api/add_text', methods=['POST']) +def add_text(): + """添加文本""" + try: + data = request.json + text = data.get('text') + + if not text: + return jsonify({"success": False, "error": "文本不能为空"}), 400 + + # 使用内存处理文本 + with file_handler.temp_file_context(text.encode('utf-8'), suffix='.txt') as temp_file: + logger.info(f"处理文本: {temp_file}") + + # 初始化检索系统 + retrieval = init_retrieval_system() + + # 添加文本 + metadata = { + "timestamp": time.time(), + "source": "web_upload" + } + + text_ids = retrieval.add_texts([text], [metadata]) + + # 保存索引 + retrieval.save_index() + + return jsonify({ + "success": True, + "message": "文本添加成功", + "text_id": text_ids[0] if text_ids else None + }) + + except Exception as e: + logger.error(f"添加文本失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + finally: + # 清理临时文件 + file_handler.cleanup_all_temp_files() + +@app.route('/api/add_image', methods=['POST']) +def add_image(): + """添加图像""" + try: + # 检查是否有文件 + if 'image' not in request.files: + return jsonify({"success": False, "error": "没有上传文件"}), 400 + + file = request.files['image'] + + # 检查文件名 + if file.filename == '': + return jsonify({"success": False, "error": "没有选择文件"}), 400 + + if file and allowed_file(file.filename): + # 读取图像数据 + image_data = file.read() + file_size = len(image_data) + + # 使用文件处理器处理图像 + logger.info(f"处理图像: {file.filename} ({file_size} 字节)") + + # 初始化检索系统 + retrieval = init_retrieval_system() + + # 创建临时文件 + file_obj = BytesIO(image_data) + filename = secure_filename(file.filename) + + # 保存到本地文件系统 + image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) + with open(image_path, 'wb') as f: + f.write(image_data) + + # 加载图像 + try: + image = Image.open(BytesIO(image_data)) + # 确保图像是RGB模式 + if image.mode != 'RGB': + logger.info(f"将图像从 {image.mode} 转换为 RGB") + image = image.convert('RGB') + + logger.info(f"成功加载图像: {filename}, 格式: {image.format}, 模式: {image.mode}, 大小: {image.size}") + except Exception as e: + logger.error(f"加载图像失败: {filename}, 错误: {str(e)}") + return jsonify({"success": False, "error": f"图像格式不支持: {str(e)}"}), 400 + + # 添加图像 + metadata = { + "filename": filename, + "timestamp": time.time(), + "source": "web_upload", + "size": file_size, + "local_path": image_path + } + + # 添加到检索系统 + image_ids = retrieval.add_images([image], [metadata], [image_path]) + + # 保存索引 + retrieval.save_index() + + return jsonify({ + "success": True, + "message": "图像添加成功", + "image_id": image_ids[0] if image_ids else None + }) + else: + return jsonify({"success": False, "error": "不支持的文件类型"}), 400 + + except Exception as e: + logger.error(f"添加图像失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + finally: + # 清理临时文件 + file_handler.cleanup_all_temp_files() + +@app.route('/api/search_by_text', methods=['POST']) +def search_by_text(): + """文本搜索""" + try: + data = request.json + query = data.get('query') + k = int(data.get('k', 5)) + filter_type = data.get('filter_type') # "text", "image" 或 null + + if not query: + return jsonify({"success": False, "error": "查询文本不能为空"}), 400 + + # 初始化检索系统 + retrieval = init_retrieval_system() + + # 执行搜索 + results = retrieval.search_by_text(query, k, filter_type) + + # 处理结果 + processed_results = [] + for result in results: + item = { + "score": result.get("score", 0), + "type": result.get("type") + } + + if result.get("type") == "text": + item["text"] = result.get("text", "") + elif result.get("type") == "image": + if "path" in result and os.path.exists(result["path"]): + item["image"] = get_image_base64(result["path"]) + item["filename"] = os.path.basename(result["path"]) + if "description" in result: + item["description"] = result["description"] + + processed_results.append(item) + + return jsonify({ + "success": True, + "results": processed_results, + "query": query, + "filter_type": filter_type + }) + + except Exception as e: + logger.error(f"文本搜索失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + +@app.route('/api/search_by_image', methods=['POST']) +def search_by_image(): + """图像搜索""" + try: + # 检查是否有文件 + if 'image' not in request.files: + return jsonify({"success": False, "error": "没有上传文件"}), 400 + + file = request.files['image'] + k = int(request.form.get('k', 5)) + filter_type = request.form.get('filter_type') # "text", "image" 或 null + + # 检查文件名 + if file.filename == '': + return jsonify({"success": False, "error": "没有选择文件"}), 400 + + if file and allowed_file(file.filename): + # 读取图像数据 + image_data = file.read() + file_size = len(image_data) + + # 根据文件大小选择处理方式 + if file_size <= 5 * 1024 * 1024: # 5MB + # 小文件使用内存处理 + logger.info(f"使用内存处理搜索图像: {file.filename} ({file_size} 字节)") + image = Image.open(BytesIO(image_data)) + + # 初始化检索系统 + retrieval = init_retrieval_system() + + # 执行搜索 + results = retrieval.search_by_image(image, k, filter_type) + else: + # 大文件使用临时文件处理 + with file_handler.temp_file_context(image_data, suffix=os.path.splitext(file.filename)[1]) as temp_file: + logger.info(f"使用临时文件处理搜索图像: {temp_file} ({file_size} 字节)") + + # 初始化检索系统 + retrieval = init_retrieval_system() + + # 加载图像 + image = Image.open(temp_file) + + # 执行搜索 + results = retrieval.search_by_image(image, k, filter_type) + + # 处理结果 + processed_results = [] + for result in results: + item = { + "score": result.get("score", 0), + "type": result.get("type") + } + + if result.get("type") == "text": + item["text"] = result.get("text", "") + elif result.get("type") == "image": + if "path" in result and os.path.exists(result["path"]): + item["image"] = get_image_base64(result["path"]) + item["filename"] = os.path.basename(result["path"]) + if "description" in result: + item["description"] = result["description"] + + processed_results.append(item) + + return jsonify({ + "success": True, + "results": processed_results, + "filter_type": filter_type + }) + else: + return jsonify({"success": False, "error": "不支持的文件类型"}), 400 + + except Exception as e: + logger.error(f"图像搜索失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + finally: + # 清理临时文件 + file_handler.cleanup_all_temp_files() + +@app.route('/uploads/') +def uploaded_file(filename): + """提供上传文件的访问""" + return send_from_directory(app.config['UPLOAD_FOLDER'], filename) + +@app.route('/temp/') +def temp_file(filename): + """提供临时文件的访问""" + return send_from_directory(app.config['UPLOAD_FOLDER'], filename) + +@app.route('/api/save_index', methods=['POST']) +def save_index(): + """保存索引""" + try: + # 初始化检索系统 + retrieval = init_retrieval_system() + + # 保存索引 + retrieval.save_index() + + return jsonify({ + "success": True, + "message": "索引保存成功" + }) + + except Exception as e: + logger.error(f"保存索引失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + +@app.route('/api/clear_index', methods=['POST']) +def clear_index(): + """清空索引""" + try: + # 初始化检索系统 + retrieval = init_retrieval_system() + + # 清空索引 + retrieval.clear_index() + + return jsonify({ + "success": True, + "message": "索引已清空" + }) + + except Exception as e: + logger.error(f"清空索引失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + +@app.route('/api/list_items', methods=['GET']) +def list_items(): + """列出所有索引项""" + try: + # 初始化检索系统 + retrieval = init_retrieval_system() + + # 获取所有项 + items = retrieval.list_items() + + return jsonify({ + "success": True, + "items": items + }) + + except Exception as e: + logger.error(f"列出索引项失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + +@app.route('/api/system_info', methods=['GET', 'POST']) +def system_info(): + """获取系统信息""" + try: + # GPU信息 + gpu_info = [] + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + gpu_info.append({ + "id": i, + "name": torch.cuda.get_device_name(i), + "memory_total": torch.cuda.get_device_properties(i).total_memory / (1024 ** 3), + "memory_allocated": torch.cuda.memory_allocated(i) / (1024 ** 3), + "memory_reserved": torch.cuda.memory_reserved(i) / (1024 ** 3) + }) + + # 检索系统信息 + retrieval_info = {} + if retrieval_system is not None: + retrieval_info = retrieval_system.get_stats() + + return jsonify({ + "success": True, + "gpu_info": gpu_info, + "retrieval_info": retrieval_info, + "model_path": app.config['MODEL_PATH'], + "index_path": app.config['INDEX_PATH'] + }) + + except Exception as e: + logger.error(f"获取系统信息失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + +if __name__ == '__main__': + try: + # 预初始化检索系统 + init_retrieval_system() + + # 启动Web应用 + app.run(host='0.0.0.0', port=5000, debug=False) + except Exception as e: + logger.error(f"启动Web应用失败: {str(e)}") + sys.exit(1) diff --git a/web_app_vdb.py b/web_app_vdb.py index a8e1ece..34bc6cd 100644 --- a/web_app_vdb.py +++ b/web_app_vdb.py @@ -514,6 +514,57 @@ def get_data_stats(): 'message': f'获取统计失败: {str(e)}' }), 500 +@app.route('/api/data/list', methods=['GET']) +def list_data(): + """获取数据列表""" + try: + # 获取图片文件列表 + image_files = [] + for ext in ALLOWED_EXTENSIONS: + pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}") + for file_path in glob.glob(pattern): + try: + # 转换为base64 + image_base64 = image_to_base64(file_path) + image_files.append({ + 'filename': os.path.basename(file_path), + 'filepath': file_path, + 'image_base64': image_base64, + 'size': os.path.getsize(file_path) + }) + except Exception as e: + logger.warning(f"处理图片文件失败 {file_path}: {e}") + + # 获取文本文件列表 + text_files = [] + text_file_paths = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json")) + text_file_paths.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))) + + for text_file in text_file_paths: + try: + text_files.append({ + 'filename': os.path.basename(text_file), + 'filepath': text_file, + 'size': os.path.getsize(text_file) + }) + except Exception as e: + logger.warning(f"处理文本文件失败 {text_file}: {e}") + + return jsonify({ + 'success': True, + 'image_files': image_files, + 'text_files': text_files, + 'image_count': len(image_files), + 'text_count': len(text_files) + }) + + except Exception as e: + logger.error(f"获取数据列表失败: {str(e)}") + return jsonify({ + 'success': False, + 'message': f'获取数据列表失败: {str(e)}' + }), 500 + @app.route('/api/data/clear', methods=['POST']) def clear_data(): """清空所有数据"""