
Training a model: RuntimeError: Numpy is not available

### Resolving the `RuntimeError: Numpy not available` error

When you encounter an error such as `RuntimeError: Numpy not available`, it means the NumPy library could not be loaded in the current environment. This can be caused by several factors, including (but not limited to) a misconfigured environment, dependency conflicts, or missing build tools.

#### 1. Install or update the NumPy library

If NumPy is not installed yet, it can be installed via pip or conda:

```bash
pip install numpy
```

For users of the Anaconda distribution, fetching the latest NumPy through the Conda channel is recommended:

```bash
conda update numpy
```

#### 2. Verify Python and NumPy version compatibility

Make sure the Python interpreter you are using matches your NumPy version; different versions may introduce API changes or other incompatibilities. Check the official documentation to confirm which combinations are supported and adjust your project's dependencies accordingly.

#### 3. Check the virtual environment setup

Sometimes the problem only occurs when the program runs inside a particular virtual environment, because that environment is missing required packages or its path variables have been modified. Try recreating a clean virtual environment and repeating the installation steps above in it.

#### 4. Confirm that a C compiler is present

In some cases NumPy needs a C compiler to build its internal components. Windows users may additionally need to install the Microsoft Visual Studio Build Tools; on Linux/Unix systems, make sure gcc is installed:

```bash
sudo apt-get install build-essential
```

The measures above help resolve problems caused by missing development tools[^1].
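A quick way to confirm the fix (not part of the original answer) is to import NumPy from the same interpreter or virtual environment that the training script uses and print its version. The `numpy<2` pin below simply reflects the advice printed in the tracebacks quoted further down, for cases where a dependency was built against NumPy 1.x:

```bash
# Check which NumPy (if any) the current interpreter sees
python -c "import numpy; print(numpy.__version__)"
python --version
pip show numpy

# If a dependency was compiled against NumPy 1.x and breaks under 2.x,
# pinning an older release is one commonly suggested workaround:
pip install "numpy<2"
```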

Related recommendations

纠错5: python Python 3.10.9 (tags/v3.10.9:1dd9be6, Dec 6 2022, 20:01:21) [MSC v.1934 64 bit (AMD64)] on win32 Type "help", "copyright", "credits" or "license" for more information. >>> import whisper >>> from transformers import MarianMTModel, MarianTokenizer >>> whisper.load_model("small") A module that was compiled using NumPy 1.x cannot be run in NumPy 2.1.3 as it may crash. To support both 1.x and 2.x versions of NumPy, modules must be compiled with NumPy 2.0. Some module may need to rebuild instead e.g. with 'pybind11>=2.12'. If you are a user of the module, the easiest solution will be to downgrade to 'numpy<2' or try to upgrade the affected module. We expect that some modules will need time to support NumPy 2. Traceback (most recent call last): File "<stdin>", line 1, in <module> File "D:\bili_translator\venv\lib\site-packages\whisper\__init__.py", line 150, in load_model checkpoint = torch.load(fp, map_location=device) File "D:\bili_translator\venv\lib\site-packages\torch\serialization.py", line 809, in load return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args) File "D:\bili_translator\venv\lib\site-packages\torch\serialization.py", line 1172, in _load result = unpickler.load() File "D:\bili_translator\venv\lib\site-packages\torch\_utils.py", line 169, in _rebuild_tensor_v2 tensor = _rebuild_tensor(storage, storage_offset, size, stride) File "D:\bili_translator\venv\lib\site-packages\torch\_utils.py", line 147, in _rebuild_tensor t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device) D:\bili_translator\venv\lib\site-packages\torch\_utils.py:147: UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:84.) t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "D:\bili_translator\venv\lib\site-packages\whisper\__init__.py", line 158, in load_model model.set_alignment_heads(alignment_heads) File "D:\bili_translator\venv\lib\site-packages\whisper\model.py", line 282, in set_alignment_heads mask = torch.from_numpy(array).reshape( RuntimeError: Numpy is not available

A module that was compiled using NumPy 1.x cannot be run in NumPy 2.0.1 as it may crash. To support both 1.x and 2.x versions of NumPy, modules must be compiled with NumPy 2.0. Some module may need to rebuild instead e.g. with 'pybind11>=2.12'. If you are a user of the module, the easiest solution will be to downgrade to 'numpy<2' or try to upgrade the affected module. We expect that some modules will need time to support NumPy 2. Traceback (most recent call last): File "D:\pythonProject1\python.py", line 5, in <module> from torchvision import datasets, transforms File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torchvision\__init__.py", line 5, in <module> from torchvision import datasets, io, models, ops, transforms, utils File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torchvision\models\__init__.py", line 17, in <module> from . import detection, optical_flow, quantization, segmentation, video File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torchvision\models\detection\__init__.py", line 1, in <module> from .faster_rcnn import * File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torchvision\models\detection\faster_rcnn.py", line 16, in <module> from .anchor_utils import AnchorGenerator File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torchvision\models\detection\anchor_utils.py", line 10, in <module> class AnchorGenerator(nn.Module): File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torchvision\models\detection\anchor_utils.py", line 63, in AnchorGenerator device: torch.device = torch.device("cpu"), C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torchvision\models\detection\anchor_utils.py:63: UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\utils\tensor_numpy.cpp:77.) device: torch.device = torch.device("cpu"), Traceback (most recent call last): File "D:\pythonProject1\python.py", line 55, in <module> for images, labels in train_loader: File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 628, in __next__ data = self._next_data() File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 671, in _next_data data = self._dataset_fetcher.fetch(index) # may raise StopIteration File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 58, in fetch data = [self.dataset[idx] for idx in possibly_batched_index] File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 58, in data = [self.dataset[idx] for idx in possibly_batched_index] File "C:\Users\Administrator\.conda\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py", line 142, in __getitem__ img = Image.fromarray(img.numpy(), mode="L") RuntimeError: Numpy is not available

Traceback (most recent call last): File "D:\Download\codeseg\code\demo_test_image.py", line 100, in <module> model.load_model("./weights/yolov8s-seg.pt") File "D:\Download\codeseg\code\model.py", line 57, in load_model self.model(torch.zeros(1, 3, *[self.imgsz] * 2).to(self.device). File "D:\Download\codeseg\code\ultralytics\engine\model.py", line 102, in __call__ return self.predict(source, stream, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "D:\Download\codeseg\code\ultralytics\engine\model.py", line 243, in predict return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "D:\Download\codeseg\code\ultralytics\engine\predictor.py", line 196, in __call__ return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\20126\.conda\envs\pytorch\Lib\site-packages\torch\utils\_contextlib.py", line 35, in generator_context response = gen.send(None) ^^^^^^^^^^^^^^ File "D:\Download\codeseg\code\ultralytics\engine\predictor.py", line 263, in stream_inference self.results = self.postprocess(preds, im, im0s) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "D:\Download\codeseg\code\ultralytics\models\yolo\segment\predict.py", line 42, in postprocess orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "D:\Download\codeseg\code\ultralytics\utils\ops.py", line 785, in convert_torch2numpy_batch return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: Numpy is not available

from bertopic import BERTopic import numpy as np import pandas as pd from umap import UMAP from hdbscan import HDBSCAN from sklearn.feature_extraction.text import CountVectorizer from bertopic.vectorizers import ClassTfidfTransformer import re import nltk from nltk.corpus import stopwords from nltk.stem import WordNetLemmatizer from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS from nltk.tokenize import word_tokenize from wordcloud import WordCloud import matplotlib.pyplot as plt # 加载原始文本数据(仍需用于主题表示) df = pd.read_csv('hydrogen_storage_patents_preprocessed.csv', encoding='utf-8') sentences = df['AB'].tolist() print('文本条数: ', len(sentences)) print('预览第一条: ', sentences[0]) # 检查缺失值 print("缺失值数量:", df['AB'].isna().sum()) # 检查非字符串类型 non_str_mask = df['AB'].apply(lambda x: not isinstance(x, str)) print("非字符串样本:\n", df[non_str_mask]['AB'].head()) vectorizer_model = None # 1. 加载时间数据 df['AD'] = pd.to_datetime(df['AD']) # 检查是否有空值 missing_dates = df['AD'].isna().sum() if missing_dates > 0: print(f"日期列中存在 {missing_dates} 个空值,请检查数据!") else: print("日期列中没有空值。") # 从Date列提取年份 years = df['AD'].dt.year print(years) from sentence_transformers import SentenceTransformer # Step 1 - Extract embeddings embedding_model = SentenceTransformer("C:\\Users\\18267\\.cache\\huggingface\\hub\\models--sentence-transformers--all-mpnet-base-v2\\snapshots\\9a3225965996d404b775526de6dbfe85d3368642") embeddings = np.load('emb_last_326.npy') print(f"嵌入的形状: {embeddings.shape}") # Step 2 - Reduce dimensionality umap_model = UMAP(n_neighbors=7, n_components=10, min_dist=0.0, metric='cosine',random_state=42) # Step 3 - Cluster reduced embeddings hdbscan_model = HDBSCAN(min_samples=7, min_cluster_size=60,metric='euclidean', cluster_selection_method='eom', prediction_data=True) # Step 4 - Tokenize topics vectorizer_model = CountVectorizer() # Step 5 - Create topic representation ctfidf_model = ClassTfidfTransformer() # All steps together topic_model = BERTopic( embedding_model=embedding_model, # Step 1 - Extract embeddings umap_model=umap_model, # Step 2 - Reduce dimensionality hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics ctfidf_model=ctfidf_model, # Step 5 - Extract topic words top_n_words=50 ) # 拟合模型 topics, probs = topic_model.fit_transform(documents=sentences, # 仍需提供文档用于主题词生成 embeddings=embeddings # 注入预计算嵌入) ) # 获取主题聚类信息 topic_info = topic_model.get_topic_info() print(topic_info) # 可视化主题词 topic_model.visualize_barchart() # 可视化主题分布 topic_model.visualize_topics() # 查看层级 hierarchical_topics = topic_model.hierarchical_topics(sentences) fig = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics) # 隐藏小黑点 for trace in fig.data: if trace.mode == 'markers': trace.marker.opacity = 0 fig.show() new_topics = topic_model.reduce_outliers(sentences, topics, strategy="c-tf-idf",threshold=0.00) print(new_topics.count(-1), new_topics) topic_model.update_topics(docs=sentences, topics=new_topics,vectorizer_model=vectorizer_model,top_n_words=50) topic_info = topic_model.get_topic_info() print(topic_info)添加94-96行代码后继续运行显示RuntimeError: Numpy is not available

RuntimeError Traceback (most recent call last) Cell In[19], line 2 1 # Train the model on the COCO8 example dataset for 100 epochs ----> 2 results = model.train(data="C:\\Users\\asus\\Downloads\\coco8.yaml", epochs=100, imgsz=640) File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\model.py:799, in Model.train(self, trainer, **kwargs) 796 self.model = self.trainer.model 798 self.trainer.hub_session = self.session # attach optional HUB session --> 799 self.trainer.train() 800 # Update model and cfg after training 801 if RANK in {-1, 0}: File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\trainer.py:227, in BaseTrainer.train(self) 224 ddp_cleanup(self, str(file)) 226 else: --> 227 self._do_train(world_size) File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\trainer.py:348, in BaseTrainer._do_train(self, world_size) 346 if world_size > 1: 347 self._setup_ddp(world_size) --> 348 self._setup_train(world_size) 350 nb = len(self.train_loader) # number of batches 351 nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\trainer.py:285, in BaseTrainer._setup_train(self, world_size) 283 if self.amp and RANK in {-1, 0}: # Single-GPU and DDP 284 callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them --> 285 self.amp = torch.tensor(check_amp(self.model), device=self.device) 286 callbacks.default_callbacks = callbacks_backup # restore callbacks 287 if RANK > -1 and world_size > 1: # DDP File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\utils\checks.py:782, in check_amp(model) 779 try: 780 from ultralytics import YOLO --> 782 assert amp_allclose(YOLO("yolo11n.pt"), im) 783 LOGGER.info(f"{prefix}checks passed ✅") 784 except ConnectionError: File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\utils\checks.py:770, in check_amp.<locals>.amp_allclose(m, im) 768 batch = [im] * 8 769 imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64 --> 770 a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference 771 with autocast(enabled=True): 772 b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\model.py:185, in Model.__call__(self, source, stream, **kwargs) 156 def __call__( 157 self, 158 source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None, 159 stream: bool = False, 160 **kwargs: Any, 161 ) -> list: 162 """ 163 Alias for the predict method, enabling the model instance to be callable for predictions. 164 (...) 183 ... 
print(f"Detected {len(r)} objects in image") 184 """ --> 185 return self.predict(source, stream, **kwargs) File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\model.py:555, in Model.predict(self, source, stream, predictor, **kwargs) 553 if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models 554 self.predictor.set_prompts(prompts) --> 555 return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\predictor.py:227, in BasePredictor.__call__(self, source, model, stream, *args, **kwargs) 225 return self.stream_inference(source, model, *args, **kwargs) 226 else: --> 227 return list(self.stream_inference(source, model, *args, **kwargs)) File D:\anaconda\envs\pytorch_env\lib\site-packages\torch\autograd\grad_mode.py:43, in _DecoratorContextManager._wrap_generator.<locals>.generator_context(*args, **kwargs) 40 try: 41 # Issuing None to a generator fires it up 42 with self.clone(): ---> 43 response = gen.send(None) 45 while True: 46 try: 47 # Forward the response to our caller and get its next request File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\predictor.py:326, in BasePredictor.stream_inference(self, source, model, *args, **kwargs) 324 # Preprocess 325 with profilers[0]: --> 326 im = self.preprocess(im0s) 328 # Inference 329 with profilers[1]: File D:\anaconda\envs\pytorch_env\lib\site-packages\ultralytics\engine\predictor.py:167, in BasePredictor.preprocess(self, im) 165 im = im.transpose((0, 3, 1, 2)) # BHWC to BCHW, (n, 3, h, w) 166 im = np.ascontiguousarray(im) # contiguous --> 167 im = torch.from_numpy(im) 169 im = im.to(self.device) 170 im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 RuntimeError: Numpy is not available

2025-03-25 09:13:17.582475: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT W load_tensorflow: 'import tensorflow' failed! Please use the following command to reinstall it: W load_tensorflow: pip3 install 'tensorflow>=1.12.0,<=2.14.0' W load_tensorflow: In addition, it is recommended that the TensorFlow version is consistent with the version of the pb model to avoid the model cannot be parsed. done --> Building model E build: The model has not been loaded, please load it first! done E export_rknn: RKNN model does not exist, please load & build model first! --> Init runtime environment E init_runtime: RKNN model does not exist, please load & build model first! Init runtime environment failed (Rknn) book@wyc_emb:~/rk3566/ai_tools/2018face$ pip3 tensorflow ERROR: unknown command "tensorflow" (Rknn) book@wyc_emb:~/rk3566/ai_tools/2018face$ pip3 show tensorflow Name: tensorflow Version: 2.13.0 Summary: TensorFlow is an open source machine learning framework for everyone. Home-page: https://www.tensorflow.org/ Author: Google Inc. Author-email: [email protected] License: Apache 2.0 Location: /home/book/anaconda3/envs/Rknn/lib/python3.8/site-packages Requires: absl-py, astunparse, flatbuffers, gast, google-pasta, grpcio, h5py, keras, libclang, numpy, opt-einsum, packaging, protobuf, setuptools, six, tensorboard, tensorflow-estimator, tensorflow-io-gcs-filesystem, termcolor, typing-extensions, wrapt Required-by: 为什么安装了合适的版本还是出错

(cvnets) D:\code\ml-cvnets-main>cvnets-train --common.config-file config/segmentation/pascal_voc/deeplabv3_mobilevitv2.yaml --common.results-loc deeplabv3_mobilevitv2_results/width_1_0_0 --common.override-kwargs model.classification.pretrained="LOCATION_OF_IMAGENET_1k_CHECKPOINT" NOTE: Redirects are currently not supported in Windows or MacOs. C:\Users\boardman\.conda\envs\cvnets\lib\site-packages\torchvision\models\detection\anchor_utils.py:63: UserWarning: Failed to initialize NumPy: module compiled against API version 0x10 but this version of numpy is 0xe (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:77.) device: torch.device = torch.device("cpu"), C:\Users\boardman\.conda\envs\cvnets\lib\site-packages\torchaudio\backend\utils.py:62: UserWarning: No audio backend is available. warnings.warn("No audio backend is available.") 2025-03-04 21:57:30 - DEBUG - Cannot load internal arguments, skipping. RuntimeError: module compiled against API version 0xf but this version of numpy is 0xe Traceback (most recent call last): File "C:\Users\boardman\.conda\envs\cvnets\lib\runpy.py", line 196, in _run_module_as_main return _run_code(code, main_globals, None, File "C:\Users\boardman\.conda\envs\cvnets\lib\runpy.py", line 86, in _run_code exec(code, run_globals) File "C:\Users\boardman\.conda\envs\cvnets\Scripts\cvnets-train.exe\__main__.py", line 7, in <module> File "D:\code\ml-cvnets-main\main_train.py", line 193, in main_worker opts = get_training_arguments(args=args) File "D:\code\ml-cvnets-main\options\opts.py", line 332, in get_training_arguments parser = METRICS_REGISTRY.all_arguments(parser=parser) File "D:\code\ml-cvnets-main\utils\registry.py", line 180, in all_arguments self._load_all() File "D:\code\ml-cvnets-main\utils\registry.py", line 97, in _load_all import_modules_from_folder(dir_name, extra_roots=self.internal_dirs) File "D:\code\ml-cvnets-main\utils\import_utils.py", line 41, in import_modules_from_folder importlib.import_module(module_name) File "C:\Users\boardman\.conda\envs\cvnets\lib\importlib\__init__.py", line 126, in import_module return _bootstrap._gcd_import(name[level:], package, level) File "<frozen importlib._bootstrap>", line 1050, in _gcd_import File "<frozen importlib._bootstrap>", line 1027, in _find_and_load File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked File "<frozen importlib._bootstrap>", line 688, in _load_unlocked File "<frozen importlib._bootstrap_external>", line 883, in exec_module File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed File "D:\code\ml-cvnets-main\metrics\average_precision.py", line 11, in <module> from sklearn.metrics import average_precision_score File "C:\Users\boardman\.conda\envs\cvnets\lib\site-packages\sklearn\__init__.py", line 82, in <module> from .base import clone File "C:\Users\boardman\.conda\envs\cvnets\lib\site-packages\sklearn\base.py", line 17, in <module> from .utils import _IS_32BIT File "C:\Users\boardman\.conda\envs\cvnets\lib\site-packages\sklearn\utils\__init__.py", line 17, in <module> from scipy.sparse import issparse File "C:\Users\boardman\.conda\envs\cvnets\lib\site-packages\scipy\sparse\__init__.py", line 267, in <module> from ._csr import * File "C:\Users\boardman\.conda\envs\cvnets\lib\site-packages\scipy\sparse\_csr.py", line 10, in <module> from ._sparsetools import (csr_tocsc, csr_tobsr, csr_count_blocks, ImportError: numpy.core.multiarray failed to import 请逐句进行分析
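For the `module compiled against API version 0xf but this version of numpy is 0xe` and `numpy.core.multiarray failed to import` messages above, the installed NumPy is older than the one scipy/scikit-learn were built against. A hedged sketch of aligning the versions inside that conda environment (exact versions are not specified in the log; pick ones that match the project's requirements):

```bash
conda activate cvnets
# Either bring NumPy up to the level the prebuilt scipy/scikit-learn expect...
pip install --upgrade numpy
# ...or reinstall the consumers so they resolve against the NumPy already present.
pip install --force-reinstall scipy scikit-learn
```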

root@autodl-container-c85144bc1a-7cf0bfa7:~# # 验证ONNX Runtime版本 python -c "import onnxruntime as ort; print(ort.__version__)" # 需要 >=1.15.0 (2024年模型要求) # 检查CUDA可用性(如使用GPU) python -c "import onnxruntime; print(onnxruntime.get_device())" # 应输出:['CPU', 'GPU:0'] 或类似 # 若出现CUDA错误(参考引用[3]),添加环境变量 export CUDA_LAUNCH_BLOCKING=1 Traceback (most recent call last): File "<string>", line 1, in <module> ModuleNotFoundError: No module named 'onnxruntime' Traceback (most recent call last): File "<string>", line 1, in <module> ModuleNotFoundError: No module named 'onnxruntime' root@autodl-container-c85144bc1a-7cf0bfa7:~# # 示例正确配置 sherpa_onnx: model_dir: "/root/autodl-tmp/Open-LLM-VTuber/models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17" sample_rate: 16000 decoding_method: "greedy_search" bash: sherpa_onnx:: command not found bash: model_dir:: command not found bash: sample_rate:: command not found bash: decoding_method:: command not found root@autodl-container-c85144bc1a-7cf0bfa7:~# import onnxruntime as ort def load_model(model_path): try: sess = ort.InferenceSession(model_path) print(f"✅ 成功加载模型: {model_path}") print(f"输入形状: {sess.get_inputs()[0].shape}") print(f"输出形状: {sess.get_outputs()[0].shape}") except Exception as e: print(f"❌ 加载失败: {str(e)}") # 测试关键模型 load_model("/root/autodl-tmp/.../encoder.onnx") load_model("/root/autodl-tmp/.../decoder.onnx") bash: import: command not found bash: syntax error near unexpected token (' bash: try:: command not found bash: syntax error near unexpected token (' bash: syntax error near unexpected token f"✅ 成功加载模型: {model_path}"' bash: syntax error near unexpected token f"输入形状: {sess.get_inputs()[0].shape}"' bash: syntax error near unexpected token f"输出形状: {sess.get_outputs()[0].shape}"' bash: except: command not found bash: syntax error near unexpected token f"❌ 加载失败: {str(e)}"' bash: syntax error near unexpected token "/root/autodl-tmp/.../encoder.onnx"' bash: syntax error near unexpected token "/root/autodl-tmp/.../decoder.onnx"' root@autodl-container-c85144bc1a-7cf0bfa7:~/autodl-tmp/Open-LLM-VTuber# pip install netron python -m netron encoder.onnx Looking in indexes: http://mirrors.aliyun.com/pypi/simple Collecting netron Downloading http://mirrors.aliyun.com/pypi/packages/86/d5/8b72b3bcf717c765945014d41e28a3d3ef67e66965c65cc325a73dbfd097/netron-8.4.4-py3-none-any.whl (1.9 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.9/1.9 MB 4.4 MB/s eta 0:00:00 Installing collected packages: netron Successfully installed netron-8.4.4 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv /root/miniconda3/bin/python: No module named netron.__main__; 'netron' is a package and cannot be directly executed root@autodl-container-c85144bc1a-7cf0bfa7:~/autodl-tmp/Open-LLM-VTuber# pip install onnxruntime==1.15.1 # 2024年模型常用兼容版本 Looking in indexes: http://mirrors.aliyun.com/pypi/simple Collecting onnxruntime==1.15.1 Downloading http://mirrors.aliyun.com/pypi/packages/2f/e2/ced4e64433097cb14425098ce3c6200b83d226005e8c23ba5bac44c89ab9/onnxruntime-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.9 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.9/5.9 MB 4.1 MB/s eta 0:00:00 Requirement already satisfied: numpy>=1.21.6 in /root/miniconda3/lib/python3.10/site-packages (from onnxruntime==1.15.1) (1.26.4) Requirement already satisfied: packaging in /root/miniconda3/lib/python3.10/site-packages (from onnxruntime==1.15.1) (24.1) Collecting flatbuffers Downloading http://mirrors.aliyun.com/pypi/packages/b8/25/155f9f080d5e4bc0082edfda032ea2bc2b8fab3f4d25d46c1e9dd22a1a89/flatbuffers-25.2.10-py2.py3-none-any.whl (30 kB) Requirement already satisfied: sympy in /root/miniconda3/lib/python3.10/site-packages (from onnxruntime==1.15.1) (1.12.1) Requirement already satisfied: protobuf in /root/miniconda3/lib/python3.10/site-packages (from onnxruntime==1.15.1) (4.25.3) Collecting coloredlogs Downloading http://mirrors.aliyun.com/pypi/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 46.0/46.0 kB 2.5 MB/s eta 0:00:00 Collecting humanfriendly>=9.1 Downloading http://mirrors.aliyun.com/pypi/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl (86 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.8/86.8 kB 3.8 MB/s eta 0:00:00 Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /root/miniconda3/lib/python3.10/site-packages (from sympy->onnxruntime==1.15.1) (1.3.0) Installing collected packages: flatbuffers, humanfriendly, coloredlogs, onnxruntime Successfully installed coloredlogs-15.0.1 flatbuffers-25.2.10 humanfriendly-10.0 onnxruntime-1.15.1 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv root@autodl-container-c85144bc1a-7cf0bfa7:~/autodl-tmp/Open-LLM-VTuber# ort_session = ort.InferenceSession("encoder.onnx", providers=['CPUExecutionProvider']) # 强制使用CPU bash: syntax error near unexpected token (' bash: syntax error near unexpected token )'
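A side note on the session above: the `import onnxruntime ...` and `ort.InferenceSession(...)` lines are Python, so the `bash: import: command not found` and `syntax error` messages come from typing them at the shell prompt. A minimal sketch of running the same check through the Python interpreter instead (the model path is a placeholder taken from the log):

```bash
python - <<'EOF'
import onnxruntime as ort

print(ort.__version__)    # verify the installed onnxruntime version
print(ort.get_device())   # "CPU" or "GPU", depending on the build

# Placeholder path; substitute the real encoder.onnx location.
sess = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
print(sess.get_inputs()[0].shape)
print(sess.get_outputs()[0].shape)
EOF
```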

A module that was compiled using NumPy 1.x cannot be run in NumPy 2.1.1 as it may crash. To support both 1.x and 2.x versions of NumPy, modules must be compiled with NumPy 2.0. Some module may need to rebuild instead e.g. with 'pybind11>=2.12'. If you are a user of the module, the easiest solution will be to downgrade to 'numpy<2' or try to upgrade the affected module. We expect that some modules will need time to support NumPy 2. Traceback (most recent call last): File "F:\ultralytics-8.3.89\train.py", line 1, in <module> from ultralytics import YOLO File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\__init__.py", line 11, in <module> from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\models\__init__.py", line 3, in <module> from .fastsam import FastSAM File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\models\fastsam\model.py", line 5, in <module> from ultralytics.engine.model import Model File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\model.py", line 12, in <module> from ultralytics.engine.results import Results File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\results.py", line 15, in <module> from ultralytics.data.augment import LetterBox File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\data\__init__.py", line 3, in <module> from .base import BaseDataset File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\data\base.py", line 17, in <module> from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\data\utils.py", line 18, in <module> from ultralytics.nn.autobackend import check_class_names File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\nn\autobackend.py", line 54, in <module> class AutoBackend(nn.Module): File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\nn\autobackend.py", line 89, in AutoBackend device=torch.device("cpu"), F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\nn\autobackend.py:89: UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:84.) 
device=torch.device("cpu"), New https://pypi.org/project/ultralytics/8.3.165 available 😃 Update with 'pip install -U ultralytics' Ultralytics 8.3.89 🚀 Python-3.11.13 torch-2.3.1+cu121 CUDA:0 (NVIDIA GeForce RTX 4080 Laptop GPU, 12282MiB) engine\trainer: task=detect, mode=train, model=yolo11n.pt, data=data.yaml, epochs=200, time=None, patience=100, batch=20, imgsz=640, save=True, save_period=-1, cache=False, device=0, workers=8, project=None, name=train2, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, show_conf=True, show_boxes=True, line_width=None, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=True, opset=None, workspace=None, nms=False, lr0=0.01, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, pose=12.0, kobj=1.0, nbs=64, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, bgr=0.0, mosaic=1.0, mixup=0.0, copy_paste=0.0, copy_paste_mode=flip, auto_augment=randaugment, erasing=0.4, crop_fraction=1.0, cfg=None, tracker=botsort.yaml, save_dir=runs\detect\train2 Overriding model.yaml nc=80 with nc=1 from n params module arguments 0 -1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2] 1 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2] 2 -1 1 6640 ultralytics.nn.modules.block.C3k2 [32, 64, 1, False, 0.25] 3 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2] 4 -1 1 26080 ultralytics.nn.modules.block.C3k2 [64, 128, 1, False, 0.25] 5 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2] 6 -1 1 87040 ultralytics.nn.modules.block.C3k2 [128, 128, 1, True] 7 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2] 8 -1 1 346112 ultralytics.nn.modules.block.C3k2 [256, 256, 1, True] 9 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5] 10 -1 1 249728 ultralytics.nn.modules.block.C2PSA [256, 256, 1] 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 12 [-1, 6] 1 0 ultralytics.nn.modules.conv.Concat [1] 13 -1 1 111296 ultralytics.nn.modules.block.C3k2 [384, 128, 1, False] 14 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 15 [-1, 4] 1 0 ultralytics.nn.modules.conv.Concat [1] 16 -1 1 32096 ultralytics.nn.modules.block.C3k2 [256, 64, 1, False] 17 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2] 18 [-1, 13] 1 0 ultralytics.nn.modules.conv.Concat [1] 19 -1 1 86720 ultralytics.nn.modules.block.C3k2 [192, 128, 1, False] 20 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2] 21 [-1, 10] 1 0 ultralytics.nn.modules.conv.Concat [1] 22 -1 1 378880 ultralytics.nn.modules.block.C3k2 [384, 256, 1, True] 23 [16, 19, 22] 1 430867 ultralytics.nn.modules.head.Detect [1, [64, 128, 256]] YOLO11n summary: 181 layers, 2,590,035 parameters, 2,590,019 gradients, 6.4 GFLOPs Transferred 448/499 items 
from pretrained weights Freezing layer 'model.23.dfl.conv.weight' AMP: running Automatic Mixed Precision (AMP) checks... Traceback (most recent call last): File "F:\ultralytics-8.3.89\train.py", line 20, in <module> main() File "F:\ultralytics-8.3.89\train.py", line 8, in main results = model.train( ^^^^^^^^^^^^ File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\model.py", line 810, in train self.trainer.train() File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\trainer.py", line 208, in train self._do_train(world_size) File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\trainer.py", line 323, in _do_train self._setup_train(world_size) File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\trainer.py", line 265, in _setup_train self.amp = torch.tensor(check_amp(self.model), device=self.device) ^^^^^^^^^^^^^^^^^^^^^ File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\utils\checks.py", line 698, in check_amp assert amp_allclose(YOLO("yolo11n.pt"), im) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\utils\checks.py", line 686, in amp_allclose a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\model.py", line 182, in __call__ return self.predict(source, stream, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\model.py", line 560, in predict return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\predictor.py", line 175, in __call__ return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "D:\anaconda\envs\model\Lib\site-packages\torch\utils\_contextlib.py", line 35, in generator_context response = gen.send(None) ^^^^^^^^^^^^^^ File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\predictor.py", line 257, in stream_inference im = self.preprocess(im0s) ^^^^^^^^^^^^^^^^^^^^^ File "F:\ultralytics-8.3.89\ultralytics-8.1.0\ultralytics\engine\predictor.py", line 130, in preprocess im = torch.from_numpy(im) ^^^^^^^^^^^^^^^^^^^^ RuntimeError: Numpy is not available

note: This error originates from a subprocess, and is likely not a problem with pip. ERROR: Failed building wheel for vllm Failed to build vllm ERROR: Could not build wheels for vllm, which is required to install pyproject.toml-based projects WARNING: You are using pip version 22.0.4; however, version 25.0.1 is available. You should consider upgrading via the 'E:\venv\Scripts\python.exe -m pip install --upgrade pip' command.根据你的环境,vllm-cpu 包确实不存在于 PyPI。请使用以下修正方案: 正确安装步骤(Windows 原生环境) powershell # 1. 确保已安装 CPU 版 PyTorch pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu # 2. 从源码安装 vLLM(跳过问题配置文件) git clone https://github.com/vllm-project/vllm.git cd vllm # 3. 删除导致路径问题的配置文件 Remove-Item vllm/model_executor/layers/fused_moe/configs -Recurse -Force # 4. 设置环境变量并安装 $env:VLLM_NO_GPU = 1 pip install . --no-deps # 5. 手动补装关键依赖 pip install accelerate>=0.25.0 transformers>=4.40.0 numpy<1.24 psutil 替代方案:使用 WSL(推荐) powershell # 1. 启用 WSL(需管理员权限) wsl --install -d Ubuntu # 2. 在 WSL 中执行: wsl sudo apt update && sudo apt install python3-pip pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu VLLM_NO_GPU=1 pip install vllm --no-deps pip install accelerate>=0.25.0 验证安装 powershell # 在代码中强制使用 CPU 模式 python -c "import os; os.environ['VLLM_NO_GPU']='1'; from vllm import LLM; llm = LLM('gpt2', enforce_eager=True)" 关键说明 路径问题根源: Windows 对包含 = [ ] 等符号的文件路径支持差 通过删除 vllm/model_executor/layers/fused_moe/configs 规避问题 永久环境变量配置: powershell Add-Content -Path $PROFILE -Value "$env:VLLM_NO_GPU=1" 如果仍遇到问题,建议关注 vLLM 的 Windows 支持进展:

import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader import h5py import numpy as np import os from tqdm import tqdm import json # 动态融合模块 class DynamicFusion(nn.Module): def __init__(self, input_dims, hidden_dim=128, lambda_val=0.5): """ 动态融合模块 参数: input_dims (dict): 各模态的输入维度 {'audio': dim_a, 'visual': dim_v, 'text': dim_t} hidden_dim (int): 公共特征空间的维度 lambda_val (float): 多模态提升能力与鉴别能力的平衡因子 """ super(DynamicFusion, self).__init__() self.lambda_val = lambda_val self.modalities = list(input_dims.keys()) self.num_modalities = len(self.modalities) # 模态投影层 (将不同模态映射到公共空间) self.projections = nn.ModuleDict() for modality, dim in input_dims.items(): self.projections[modality] = nn.Sequential( nn.Linear(dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim) ) # 重要性评估器 self.importance_evaluator = nn.Sequential( nn.Linear(hidden_dim * self.num_modalities, 256), nn.ReLU(), nn.Linear(256, self.num_modalities) ) # 预测器 self.predictor = nn.Sequential( nn.Linear(hidden_dim, 128), nn.ReLU(), nn.Linear(128, 1) # 假设是二分类任务 ) def forward(self, features, labels=None, return_importance=False): """ 前向传播 参数: features (dict): 各模态特征 {'audio': tensor, 'visual': tensor, 'text': tensor} labels (Tensor): 真实标签 (仅在训练时使用) return_importance (bool): 是否返回模态重要性 返回: fused_output: 融合后的预测结果 importance_scores: 模态重要性分数 (如果return_importance=True) """ # 1. 投影到公共特征空间 projected = {} for modality in self.modalities: projected[modality] = self.projections[modality](features[modality]) # 2. 计算全模态融合预测 all_features = torch.cat([projected[m] for m in self.modalities], dim=1) importance_logits = self.importance_evaluator(all_features) importance_weights = F.softmax(importance_logits, dim=1) # 3. 加权融合 fused_feature = torch.zeros_like(projected[self.modalities[0]]) for i, modality in enumerate(self.modalities): fused_feature += importance_weights[:, i].unsqueeze(1) * projected[modality] fused_output = self.predictor(fused_feature) if not self.training or labels is None: if return_importance: return fused_output, importance_weights return fused_output # 4. 
训练时计算监督信号 # 4.1 计算全模态损失 full_loss = F.binary_cross_entropy_with_logits(fused_output, labels, reduction='none') # 4.2 计算单模态损失 single_losses = {} for modality in self.modalities: single_pred = self.predictor(projected[modality]) single_losses[modality] = F.binary_cross_entropy_with_logits(single_pred, labels, reduction='none') # 4.3 计算移除各模态后的损失 remove_losses = {} for modality_to_remove in self.modalities: # 重归一化剩余模态的权重 remaining_modalities = [m for m in self.modalities if m != modality_to_remove] remaining_indices = [self.modalities.index(m) for m in remaining_modalities] # 计算剩余模态的归一化权重 remaining_logits = importance_logits[:, remaining_indices] remaining_weights = F.softmax(remaining_logits, dim=1) # 剩余模态的融合 fused_remove = torch.zeros_like(fused_feature) for i, modality in enumerate(remaining_modalities): idx = remaining_modalities.index(modality) fused_remove += remaining_weights[:, idx].unsqueeze(1) * projected[modality] remove_pred = self.predictor(fused_remove) remove_losses[modality_to_remove] = F.binary_cross_entropy_with_logits( remove_pred, labels, reduction='none' ) # 4.4 计算监督信号 - 模态重要性标签 importance_labels = {} for modality in self.modalities: L_m = single_losses[modality] L_full = full_loss L_remove = remove_losses[modality] # 计算多模态提升能力 multimodal_boost = L_full - L_remove # 公式(2-38): I_m = -[λ·(L - L̂) + (1-λ)·L_m] I_m = -(self.lambda_val * multimodal_boost + (1 - self.lambda_val) * L_m) importance_labels[modality] = I_m.detach() # 分离计算图 # 4.5 重要性监督损失 (使用排序损失) importance_loss = 0 for i, modality_i in enumerate(self.modalities): for j, modality_j in enumerate(self.modalities): if i >= j: continue # 获取预测的重要性分数差异 score_diff = importance_logits[:, i] - importance_logits[:, j] # 获取真实重要性标签差异 label_diff = importance_labels[modality_i] - importance_labels[modality_j] # 使用hinge loss进行排序优化 loss_pair = F.relu(-label_diff * score_diff + 0.1) importance_loss += loss_pair.mean() # 4.6 总损失 = 主任务损失 + 重要性监督损失 main_loss = full_loss.mean() total_loss = main_loss + importance_loss if return_importance: return total_loss, importance_weights return total_loss # 数据集类 - 修改为使用带噪声的特征文件和单个标签文件 class NoisyMultimodalDataset(Dataset): def __init__(self, base_dir, split='train'): """ 带噪声的多模态数据集 参数: base_dir (str): 数据集基础目录 split (str): 数据集分割 (train/val/test) """ self.base_dir = base_dir self.split = split self.sample_ids = [] # 定义模态文件路径 - 使用带噪声的文件 self.modality_files = { 'audio': os.path.join(base_dir, "A", "acoustic_noisy.h5"), 'visual': os.path.join(base_dir, "V", "visual_noisy.h5"), 'text': os.path.join(base_dir, "L", "deberta-v3-large_noisy.h5") } # 检查文件存在性 self.available_modalities = [] for modality, file_path in self.modality_files.items(): if os.path.exists(file_path): self.available_modalities.append(modality) print(f"找到模态文件: {modality} -> {file_path}") else: print(f"警告: 模态文件不存在: {file_path}") if not self.available_modalities: raise FileNotFoundError("未找到任何模态文件!") # 验证文件存在性并获取样本ID sample_ids_sets = [] for modality in self.available_modalities: file_path = self.modality_files[modality] try: with h5py.File(file_path, 'r') as f: sample_ids = set(f.keys()) sample_ids_sets.append(sample_ids) print(f"模态 {modality}: 找到 {len(sample_ids)} 个样本") except Exception as e: print(f"加载模态 {modality} 文件时出错: {str(e)}") sample_ids_sets.append(set()) # 取所有模态样本ID的交集 common_ids = set.intersection(*sample_ids_sets) self.sample_ids = sorted(list(common_ids)) # 加载标签文件 label_file = os.path.join(base_dir, "labels", "all_labels.npy") if not os.path.exists(label_file): # 尝试在基础目录下直接查找 label_file = os.path.join(base_dir, 
"all_labels.npy") if not os.path.exists(label_file): raise FileNotFoundError(f"标签文件不存在: {os.path.join(base_dir, 'labels', 'all_labels.npy')} 或 {label_file}") try: all_labels = np.load(label_file) print(f"加载标签文件: {label_file}, 总标签数: {len(all_labels)}") except Exception as e: raise IOError(f"加载标签文件时出错: {str(e)}") # 创建样本ID到标签索引的映射 id_to_index = {} try: # 使用第一个模态文件作为参考顺序 first_modality = self.available_modalities[0] with h5py.File(self.modality_files[first_modality], 'r') as f: all_ids_in_file = list(f.keys()) for idx, sample_id in enumerate(all_ids_in_file): id_to_index[sample_id] = idx print(f"使用 {first_modality} 模态的顺序作为标签映射参考") except Exception as e: raise RuntimeError(f"创建ID映射时出错: {str(e)}") # 获取当前分割样本的标签 self.labels = [] valid_sample_ids = [] missing_samples = 0 for sample_id in self.sample_ids: if sample_id in id_to_index: idx = id_to_index[sample_id] if idx < len(all_labels): self.labels.append(all_labels[idx]) valid_sample_ids.append(sample_id) else: print(f"警告: 样本 {sample_id} 的索引 {idx} 超出标签范围 (标签长度={len(all_labels)})") missing_samples += 1 else: print(f"警告: 样本 {sample_id} 未在标签映射中找到") missing_samples += 1 if missing_samples > 0: print(f"警告: {missing_samples} 个样本缺少标签") if not valid_sample_ids: raise ValueError("没有有效的样本ID与标签匹配") self.sample_ids = valid_sample_ids self.labels = np.array(self.labels, dtype=np.float32) print(f"加载 {split} 数据集: {len(self)} 个样本, {len(self.available_modalities)} 个模态") print(f"样本示例: {self.sample_ids[:3]}, 标签示例: {self.labels[:3]}") def __len__(self): return len(self.sample_ids) def __getitem__(self, idx): sample_id = self.sample_ids[idx] features = {} # 加载各模态特征 for modality in self.available_modalities: file_path = self.modality_files[modality] try: with h5py.File(file_path, 'r') as f: features[modality] = f[sample_id][:].astype(np.float32) except Exception as e: print(f"加载样本 {sample_id} 的 {modality} 特征时出错: {str(e)}") # 如果某个特征加载失败,使用零填充 if modality == 'audio': features[modality] = np.zeros(74, dtype=np.float32) elif modality == 'visual': features[modality] = np.zeros(47, dtype=np.float32) elif modality == 'text': features[modality] = np.zeros(1024, dtype=np.float32) else: features[modality] = np.zeros(128, dtype=np.float32) # 默认 label = self.labels[idx] return features, label # 训练函数 def train(model, dataloader, optimizer, device): model.train() total_loss = 0 for features, labels in tqdm(dataloader, desc="训练中"): # 将数据移至设备 for modality in features.keys(): features[modality] = features[modality].to(device) labels = labels.to(device).unsqueeze(1) # 添加维度以匹配输出 # 前向传播 loss = model(features, labels) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() avg_loss = total_loss / len(dataloader) return avg_loss # 验证函数 def validate(model, dataloader, device): model.eval() total_loss = 0 all_preds = [] all_labels = [] with torch.no_grad(): for features, labels in tqdm(dataloader, desc="验证中"): # 将数据移至设备 for modality in features.keys(): features[modality] = features[modality].to(device) labels = labels.to(device).unsqueeze(1) # 前向传播 outputs = model(features, labels) # 计算损失 loss = F.binary_cross_entropy_with_logits(outputs, labels) total_loss += loss.item() # 收集预测结果和标签 preds = torch.sigmoid(outputs) all_preds.append(preds.cpu()) all_labels.append(labels.cpu()) # 计算指标 all_preds = torch.cat(all_preds) all_labels = torch.cat(all_labels) # 计算准确率 pred_labels = (all_preds > 0.5).float() accuracy = (pred_labels == all_labels).float().mean().item() avg_loss = total_loss / len(dataloader) return avg_loss, accuracy # 主函数 def main(): # 配置 device = 
torch.device("cuda:1" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") base_dir = "/home/msj_team/data/code/CIF-MMIN/data/dataset/mositext" # 检查基础目录是否存在 if not os.path.exists(base_dir): print(f"错误: 基础目录不存在: {base_dir}") print("请检查路径是否正确,或确保数据集已下载并放置在正确位置") return # 检查标签目录是否存在 label_dir = os.path.join(base_dir, "labels") if not os.path.exists(label_dir): print(f"警告: 标签目录不存在: {label_dir}") print("尝试在基础目录下查找标签文件...") # 检查标签文件是否存在 label_file = os.path.join(label_dir, "all_labels.npy") if not os.path.exists(label_file): label_file = os.path.join(base_dir, "all_labels.npy") if not os.path.exists(label_file): print(f"错误: 标签文件不存在: {os.path.join(label_dir, 'all_labels.npy')} 或 {label_file}") return print(f"使用标签文件: {label_file}") # 根据实际特征维度设置 input_dims = { 'audio': 74, # 音频特征维度 'visual': 47, # 视觉特征维度 'text': 1024 # 文本特征维度 } # 训练参数 batch_size = 32 learning_rate = 1e-4 num_epochs = 20 lambda_val = 0.7 # 平衡因子 try: # 创建数据集和数据加载器 print("\n初始化训练数据集...") train_dataset = NoisyMultimodalDataset(base_dir, split='train') print("\n初始化验证数据集...") val_dataset = NoisyMultimodalDataset(base_dir, split='val') print(f"\n训练集样本数: {len(train_dataset)}") print(f"验证集样本数: {len(val_dataset)}") # 检查数据集是否为空 if len(train_dataset) == 0: raise ValueError("训练数据集为空!") if len(val_dataset) == 0: raise ValueError("验证数据集为空!") train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4) # 初始化模型 model = DynamicFusion(input_dims, lambda_val=lambda_val).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 训练循环 best_val_accuracy = 0.0 print("\n开始训练...") for epoch in range(num_epochs): print(f"\nEpoch {epoch+1}/{num_epochs}") train_loss = train(model, train_loader, optimizer, device) val_loss, val_accuracy = validate(model, val_loader, device) print(f"训练损失: {train_loss:.4f}, 验证损失: {val_loss:.4f}, 验证准确率: {val_accuracy:.4f}") # 保存最佳模型 if val_accuracy > best_val_accuracy: best_val_accuracy = val_accuracy torch.save(model.state_dict(), "best_model_noisy.pth") print(f"保存最佳模型 (验证准确率: {val_accuracy:.4f})") print("\n训练完成!") print(f"最终验证准确率: {best_val_accuracy:.4f}") except Exception as e: print(f"\n发生错误: {str(e)}") import traceback traceback.print_exc() if __name__ == "__main__": main() 这是我的代码,要怎样修改

检查代码是否可运行,是否高效,是否可CPUimport sys import os import json import time import wave import numpy as np import pandas as pd import matplotlib.pyplot as plt import soundfile as sf from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QLineEdit, QTextEdit, QFileDialog, QProgressBar, QGroupBox, QComboBox, QCheckBox, QMessageBox) from PyQt5.QtCore import QThread, pyqtSignal from pydub import AudioSegment from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification import whisper from pyannote.audio import Pipeline from docx import Document from docx.shared import Inches import librosa import tempfile from collections import defaultdict import re from concurrent.futures import ThreadPoolExecutor, as_completed import torch from torch.cuda import is_available as cuda_available import logging import gc # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # 全局模型缓存 MODEL_CACHE = {} class AnalysisThread(QThread): progress = pyqtSignal(int) message = pyqtSignal(str) analysis_complete = pyqtSignal(dict) error = pyqtSignal(str) def __init__(self, audio_files, keyword_file, whisper_model_path, pyannote_model_path, emotion_model_path): super().__init__() self.audio_files = audio_files self.keyword_file = keyword_file self.whisper_model_path = whisper_model_path self.pyannote_model_path = pyannote_model_path self.emotion_model_path = emotion_model_path self.running = True self.cached_models = {} self.temp_files = [] # 用于管理临时文件 self.lock = torch.multiprocessing.Lock() # 用于模型加载的锁 def run(self): try: # 加载关键词 self.message.emit("正在加载关键词...") keywords = self.load_keywords() # 预加载模型 self.message.emit("正在预加载模型...") self.preload_models() results = [] total_files = len(self.audio_files) for idx, audio_file in enumerate(self.audio_files): if not self.running: self.message.emit("分析已停止") return self.message.emit(f"正在处理文件: {os.path.basename(audio_file)} ({idx + 1}/{total_files})") file_result = self.analyze_file(audio_file, keywords) if file_result: results.append(file_result) # 定期清理内存 if idx % 5 == 0: gc.collect() torch.cuda.empty_cache() if cuda_available() else None self.progress.emit(int((idx + 1) / total_files * 100)) self.analysis_complete.emit({"results": results, "keywords": keywords}) self.message.emit("分析完成!") except Exception as e: import traceback error_msg = f"分析过程中发生错误: {str(e)}\n{traceback.format_exc()}" self.error.emit(error_msg) logger.error(error_msg) finally: # 清理临时文件 self.cleanup_temp_files() def cleanup_temp_files(self): """清理所有临时文件""" for temp_file in self.temp_files: if os.path.exists(temp_file): try: os.unlink(temp_file) except Exception as e: logger.warning(f"删除临时文件失败: {temp_file}, 原因: {str(e)}") def preload_models(self): """预加载所有模型到缓存(添加线程安全)""" global MODEL_CACHE # 使用锁确保线程安全 with self.lock: # 检查全局缓存是否已加载模型 if 'whisper' in MODEL_CACHE and 'pyannote' in MODEL_CACHE and 'emotion_classifier' in MODEL_CACHE: self.cached_models = MODEL_CACHE self.message.emit("使用缓存的模型") return self.cached_models = {} try: # 加载语音识别模型 if 'whisper' not in MODEL_CACHE: self.message.emit("正在加载语音识别模型...") MODEL_CACHE['whisper'] = whisper.load_model( self.whisper_model_path, device="cuda" if cuda_available() else "cpu" ) self.cached_models['whisper'] = MODEL_CACHE['whisper'] # 加载说话人分离模型 if 'pyannote' not in MODEL_CACHE: self.message.emit("正在加载说话人分离模型...") MODEL_CACHE['pyannote'] = Pipeline.from_pretrained( self.pyannote_model_path, 
use_auth_token=True ) self.cached_models['pyannote'] = MODEL_CACHE['pyannote'] # 加载情感分析模型 if 'emotion_classifier' not in MODEL_CACHE: self.message.emit("正在加载情感分析模型...") device = 0 if cuda_available() else -1 tokenizer = AutoTokenizer.from_pretrained(self.emotion_model_path) model = AutoModelForSequenceClassification.from_pretrained(self.emotion_model_path) # 尝试使用半精度浮点数减少内存占用 try: if device != -1: model = model.half() except Exception: pass # 如果失败则继续使用全精度 MODEL_CACHE['emotion_classifier'] = pipeline( "text-classification", model=model, tokenizer=tokenizer, device=device ) self.cached_models['emotion_classifier'] = MODEL_CACHE['emotion_classifier'] except Exception as e: raise Exception(f"模型加载失败: {str(e)}") def analyze_file(self, audio_file, keywords): """分析单个音频文件(优化内存使用)""" try: # 确保音频为WAV格式 wav_file, is_temp = self.convert_to_wav(audio_file) if is_temp: self.temp_files.append(wav_file) # 获取音频信息 duration, sample_rate, channels = self.get_audio_info(wav_file) # 说话人分离 - 使用较小的音频片段处理大文件 diarization = self.process_diarization(wav_file, duration) # 识别客服和客户 agent_segments, customer_segments = self.identify_speakers(wav_file, diarization, keywords['opening']) # 并行处理客服和客户音频 agent_result, customer_result = {}, {} with ThreadPoolExecutor(max_workers=2) as executor: agent_future = executor.submit( self.process_speaker_audio, wav_file, agent_segments, "客服" ) customer_future = executor.submit( self.process_speaker_audio, wav_file, customer_segments, "客户" ) agent_result = agent_future.result() customer_result = customer_future.result() # 情感分析 - 批处理提高效率 agent_emotion, customer_emotion = self.analyze_emotions( [agent_result.get('text', ''), customer_result.get('text', '')] ) # 服务规范检查 opening_check = self.check_opening(agent_result.get('text', ''), keywords['opening']) closing_check = self.check_closing(agent_result.get('text', ''), keywords['closing']) forbidden_check = self.check_forbidden(agent_result.get('text', ''), keywords['forbidden']) # 沟通技巧分析 speech_rate = self.analyze_speech_rate(agent_result.get('segments', [])) volume_analysis = self.analyze_volume(wav_file, agent_segments, sample_rate) # 问题解决率分析 resolution_rate = self.analyze_resolution( agent_result.get('text', ''), customer_result.get('text', ''), keywords['resolution'] ) return { "file_name": os.path.basename(audio_file), "duration": duration, "agent_text": agent_result.get('text', ''), "customer_text": customer_result.get('text', ''), "opening_check": opening_check, "closing_check": closing_check, "forbidden_check": forbidden_check, "agent_emotion": agent_emotion, "customer_emotion": customer_emotion, "speech_rate": speech_rate, "volume_mean": volume_analysis.get('mean', -60), "volume_std": volume_analysis.get('std', 0), "resolution_rate": resolution_rate } except Exception as e: error_msg = f"处理文件 {os.path.basename(audio_file)} 时出错: {str(e)}" self.error.emit(error_msg) logger.error(error_msg, exc_info=True) return None finally: # 清理临时文件 if is_temp and os.path.exists(wav_file): try: os.unlink(wav_file) except Exception: pass def process_diarization(self, wav_file, duration): """分块处理说话人分离,避免大文件内存溢出""" # 对于短音频直接处理 if duration <= 600: # 10分钟以下 return self.cached_models['pyannote'](wav_file) # 对于长音频分块处理 self.message.emit(f"音频较长({duration:.1f}秒),将分块处理...") diarization_result = [] chunk_size = 300 # 5分钟块 for start in range(0, int(duration), chunk_size): if not self.running: return [] end = min(start + chunk_size, duration) self.message.emit(f"处理片段: {start}-{end}秒") # 提取音频片段 with tempfile.NamedTemporaryFile(suffix='.wav') as tmpfile: 
self.extract_audio_segment(wav_file, start, end, tmpfile.name) segment_diarization = self.cached_models['pyannote'](tmpfile.name) # 调整时间偏移 for segment, _, speaker in segment_diarization.itertracks(yield_label=True): diarization_result.append(( segment.start + start, segment.end + start, speaker )) return diarization_result def extract_audio_segment(self, input_file, start_sec, end_sec, output_file): """提取音频片段""" audio = AudioSegment.from_wav(input_file) start_ms = int(start_sec * 1000) end_ms = int(end_sec * 1000) segment = audio[start_ms:end_ms] segment.export(output_file, format="wav") def process_speaker_audio(self, wav_file, segments, speaker_type): """处理说话人音频(优化内存使用)""" if not segments: return {'text': "", 'segments': []} text = "" segment_details = [] whisper_model = self.cached_models['whisper'] # 处理每个片段 for idx, (start, end) in enumerate(segments): if not self.running: break # 每处理5个片段报告一次进度 if idx % 5 == 0: self.message.emit(f"{speaker_type}: 处理片段 {idx+1}/{len(segments)}") duration = end - start segment_text = self.transcribe_audio_segment(wav_file, start, end, whisper_model) segment_details.append({ 'start': start, 'end': end, 'duration': duration, 'text': segment_text }) text += segment_text + " " return { 'text': text.strip(), 'segments': segment_details } def identify_speakers(self, wav_file, diarization, opening_keywords): """ 改进的客服识别方法 1. 检查前三个片段是否有开场白关键词 2. 如果片段不足三个,则检查所有存在的片段 3. 如果无法确定客服,则默认第二个说话人是客服 """ if not diarization: return [], [] speaker_segments = defaultdict(list) speaker_first_occurrence = {} # 记录每个说话人的首次出现时间 # 收集所有说话人片段并记录首次出现时间 for item in diarization: if len(item) == 3: # 来自分块处理的结果 start, end, speaker = item else: # 来自pyannote的直接结果 segment, _, speaker = item start, end = segment.start, segment.end speaker_segments[speaker].append((start, end)) if speaker not in speaker_first_occurrence or start < speaker_first_occurrence[speaker]: speaker_first_occurrence[speaker] = start # 如果没有说话人 if not speaker_segments: return [], [] # 如果只有一个说话人 if len(speaker_segments) == 1: speaker = list(speaker_segments.keys())[0] return speaker_segments[speaker], [] # 计算每个说话人的开场白得分 speaker_scores = {} whisper_model = self.cached_models['whisper'] for speaker, segments in speaker_segments.items(): score = 0 # 检查前三个片段(如果存在) check_segments = segments[:3] # 最多取前三个片段 for start, end in check_segments: # 转录片段 text = self.transcribe_audio_segment(wav_file, start, end, whisper_model) # 检查开场白关键词 for keyword in opening_keywords: if keyword and keyword in text: score += 1 break # 找到一个关键词就加分并跳出循环 speaker_scores[speaker] = score # 尝试找出得分最高的说话人 max_score = max(speaker_scores.values()) max_speakers = [spk for spk, score in speaker_scores.items() if score == max_score] # 如果有唯一最高分说话人,作为客服 if len(max_speakers) == 1: agent_speaker = max_speakers[0] else: # 无法通过开场白确定客服时,默认第二个说话人是客服 # 按首次出现时间排序 sorted_speakers = sorted(speaker_first_occurrence.items(), key=lambda x: x[1]) # 确保至少有两个说话人 if len(sorted_speakers) >= 2: # 取时间上第二个出现的说话人 agent_speaker = sorted_speakers[1][0] else: # 如果只有一个说话人(理论上不会进入此分支,但安全处理) agent_speaker = sorted_speakers[0][0] # 分离客服和客户片段 agent_segments = speaker_segments[agent_speaker] customer_segments = [] for speaker, segments in speaker_segments.items(): if speaker != agent_speaker: customer_segments.extend(segments) return agent_segments, customer_segments def load_keywords(self): """从Excel文件加载关键词(增强健壮性)""" try: df = pd.read_excel(self.keyword_file) # 确保列存在 columns = ['opening', 'closing', 'forbidden', 'resolution'] for col in columns: if col not in df.columns: raise 
ValueError(f"关键词文件缺少必要列: {col}") keywords = { "opening": [str(k).strip() for k in df['opening'].dropna().tolist() if str(k).strip()], "closing": [str(k).strip() for k in df['closing'].dropna().tolist() if str(k).strip()], "forbidden": [str(k).strip() for k in df['forbidden'].dropna().tolist() if str(k).strip()], "resolution": [str(k).strip() for k in df['resolution'].dropna().tolist() if str(k).strip()] } # 检查是否有足够的关键词 if not any(keywords.values()): raise ValueError("关键词文件中没有找到有效关键词") return keywords except Exception as e: raise Exception(f"加载关键词文件失败: {str(e)}") def convert_to_wav(self, audio_file): """将音频文件转换为WAV格式(增强健壮性)""" try: if not os.path.exists(audio_file): raise FileNotFoundError(f"音频文件不存在: {audio_file}") if audio_file.lower().endswith('.wav'): return audio_file, False # 使用临时文件避免磁盘IO with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpfile: output_file = tmpfile.name audio = AudioSegment.from_file(audio_file) audio.export(output_file, format='wav') return output_file, True except Exception as e: raise Exception(f"音频转换失败: {str(e)}") def get_audio_info(self, wav_file): """获取音频文件信息(增强健壮性)""" try: if not os.path.exists(wav_file): raise FileNotFoundError(f"音频文件不存在: {wav_file}") # 使用soundfile获取更可靠的信息 with sf.SoundFile(wav_file) as f: duration = len(f) / f.samplerate sample_rate = f.samplerate channels = f.channels return duration, sample_rate, channels except Exception as e: raise Exception(f"获取音频信息失败: {str(e)}") def transcribe_audio_segment(self, wav_file, start, end, model): """转录单个音频片段 - 优化内存使用""" # 使用pydub加载音频 audio = AudioSegment.from_wav(wav_file) # 转换为毫秒 start_ms = int(start * 1000) end_ms = int(end * 1000) segment_audio = audio[start_ms:end_ms] # 使用临时文件 with tempfile.NamedTemporaryFile(suffix='.wav') as tmpfile: segment_audio.export(tmpfile.name, format="wav") try: result = model.transcribe( tmpfile.name, fp16=cuda_available() # 使用FP16加速(如果可用) ) return result['text'] except RuntimeError as e: if "out of memory" in str(e).lower(): # 尝试释放内存后重试 torch.cuda.empty_cache() gc.collect() result = model.transcribe( tmpfile.name, fp16=cuda_available() ) return result['text'] raise def analyze_emotions(self, texts): """批量分析文本情感(提高效率)""" if not any(t.strip() for t in texts): return [{"label": "中性", "score": 0.0} for _ in texts] # 截断长文本以提高性能 processed_texts = [t[:500] if len(t) > 500 else t for t in texts] # 批量处理 classifier = self.cached_models['emotion_classifier'] results = classifier(processed_texts, truncation=True, max_length=512, batch_size=4) # 确保返回格式一致 emotions = [] for result in results: if isinstance(result, list) and result: emotions.append({ "label": result[0]['label'], "score": result[0]['score'] }) else: emotions.append({ "label": "中性", "score": 0.0 }) return emotions def check_opening(self, text, opening_keywords): """检查开场白(使用正则表达式提高准确性)""" if not text or not opening_keywords: return False pattern = "|".join(re.escape(k) for k in opening_keywords) return bool(re.search(pattern, text)) def check_closing(self, text, closing_keywords): """检查结束语(使用正则表达式提高准确性)""" if not text or not closing_keywords: return False pattern = "|".join(re.escape(k) for k in closing_keywords) return bool(re.search(pattern, text)) def check_forbidden(self, text, forbidden_keywords): """检查服务禁语(使用正则表达式提高准确性)""" if not text or not forbidden_keywords: return False pattern = "|".join(re.escape(k) for k in forbidden_keywords) return bool(re.search(pattern, text)) def analyze_speech_rate(self, segments): """改进的语速分析 - 基于实际识别文本""" if not segments: return 0 total_chars = 0 total_duration = 0 for 
segment in segments: # 计算片段时长(秒) duration = segment['duration'] total_duration += duration # 计算中文字符数(去除标点和空格) chinese_chars = sum(1 for char in segment['text'] if '\u4e00' <= char <= '\u9fff') total_chars += chinese_chars if total_duration == 0: return 0 # 语速 = 总字数 / 总时长(分钟) return total_chars / (total_duration / 60) def analyze_volume(self, wav_file, segments, sample_rate): """改进的音量分析 - 使用librosa计算RMS分贝值""" if not segments: return {"mean": -60, "std": 0} # 使用soundfile加载音频(更高效) try: y, sr = sf.read(wav_file, dtype='float32') if sr != sample_rate: y = librosa.resample(y, orig_sr=sr, target_sr=sample_rate) sr = sample_rate except Exception: # 回退到librosa y, sr = librosa.load(wav_file, sr=sample_rate, mono=True) all_dB = [] for start, end in segments: start_sample = int(start * sr) end_sample = int(end * sr) # 确保片段在有效范围内 if start_sample < len(y) and end_sample <= len(y): segment_audio = y[start_sample:end_sample] # 计算RMS并转换为dB rms = librosa.feature.rms(y=segment_audio)[0] dB = librosa.amplitude_to_db(rms, ref=1.0) # 使用标准参考值 all_dB.extend(dB) if not all_dB: return {"mean": -60, "std": 0} return { "mean": float(np.mean(all_dB)), "std": float(np.std(all_dB)) } def analyze_resolution(self, agent_text, customer_text, resolution_keywords): """分析问题解决率(使用更智能的匹配)""" # 检查客户是否提到问题 problem_patterns = [ "问题", "故障", "解决", "怎么办", "如何", "为什么", "不行", "不能", "无法", "错误", "bug", "issue", "疑问", "咨询" ] problem_regex = re.compile("|".join(problem_patterns)) has_problem = bool(problem_regex.search(customer_text)) # 检查客服是否提供解决方案 solution_regex = re.compile("|".join(re.escape(k) for k in resolution_keywords)) solution_found = bool(solution_regex.search(agent_text)) # 如果没有检测到问题,则认为已解决 if not has_problem: return True return solution_found def stop(self): """停止分析""" self.running = False self.message.emit("正在停止分析...") class MainWindow(QMainWindow): def __init__(self): super().__init__() self.setWindowTitle("外呼电话录音包质检分析系统") self.setGeometry(100, 100, 1000, 700) self.setStyleSheet(""" QMainWindow { background-color: #f0f0f0; } QGroupBox { font-weight: bold; border: 1px solid gray; border-radius: 5px; margin-top: 1ex; } QGroupBox::title { subcontrol-origin: margin; left: 10px; padding: 0 5px; } QPushButton { background-color: #4CAF50; color: white; border: none; padding: 5px 10px; border-radius: 3px; } QPushButton:hover { background-color: #45a049; } QPushButton:disabled { background-color: #cccccc; } QProgressBar { border: 1px solid grey; border-radius: 3px; text-align: center; } QProgressBar::chunk { background-color: #4CAF50; width: 10px; } QTextEdit { font-family: Consolas, Monaco, monospace; } """) # 初始化变量 self.audio_files = [] self.keyword_file = "" self.whisper_model_path = "./models/whisper-small" self.pyannote_model_path = "./models/pyannote-speaker-diarization" self.emotion_model_path = "./models/Erlangshen-Roberta-110M-Sentiment" self.output_dir = os.path.expanduser("~/质检报告") # 创建主控件 central_widget = QWidget() self.setCentralWidget(central_widget) main_layout = QVBoxLayout(central_widget) main_layout.setSpacing(10) main_layout.setContentsMargins(15, 15, 15, 15) # 文件选择区域 file_group = QGroupBox("文件选择") file_layout = QVBoxLayout(file_group) file_layout.setSpacing(8) # 音频文件选择 audio_layout = QHBoxLayout() self.audio_label = QLabel("音频文件/文件夹:") audio_layout.addWidget(self.audio_label) self.audio_path_edit = QLineEdit() self.audio_path_edit.setPlaceholderText("请选择音频文件或文件夹") audio_layout.addWidget(self.audio_path_edit, 3) self.audio_browse_btn = QPushButton("浏览...") self.audio_browse_btn.clicked.connect(self.browse_audio) 
audio_layout.addWidget(self.audio_browse_btn) file_layout.addLayout(audio_layout) # 关键词文件选择 keyword_layout = QHBoxLayout() self.keyword_label = QLabel("关键词文件:") keyword_layout.addWidget(self.keyword_label) self.keyword_path_edit = QLineEdit() self.keyword_path_edit.setPlaceholderText("请选择Excel格式的关键词文件") keyword_layout.addWidget(self.keyword_path_edit, 3) self.keyword_browse_btn = QPushButton("浏览...") self.keyword_browse_btn.clicked.connect(self.browse_keyword) keyword_layout.addWidget(self.keyword_browse_btn) file_layout.addLayout(keyword_layout) main_layout.addWidget(file_group) # 模型设置区域 model_group = QGroupBox("模型设置") model_layout = QVBoxLayout(model_group) model_layout.setSpacing(8) # Whisper模型路径 whisper_layout = QHBoxLayout() whisper_layout.addWidget(QLabel("Whisper模型路径:")) self.whisper_edit = QLineEdit(self.whisper_model_path) whisper_layout.addWidget(self.whisper_edit, 3) model_layout.addLayout(whisper_layout) # Pyannote模型路径 pyannote_layout = QHBoxLayout() pyannote_layout.addWidget(QLabel("Pyannote模型路径:")) self.pyannote_edit = QLineEdit(self.pyannote_model_path) pyannote_layout.addWidget(self.pyannote_edit, 3) model_layout.addLayout(pyannote_layout) # 情感分析模型路径 emotion_layout = QHBoxLayout() emotion_layout.addWidget(QLabel("情感分析模型路径:")) self.emotion_edit = QLineEdit(self.emotion_model_path) emotion_layout.addWidget(self.emotion_edit, 3) model_layout.addLayout(emotion_layout) # 输出目录 output_layout = QHBoxLayout() output_layout.addWidget(QLabel("输出目录:")) self.output_edit = QLineEdit(self.output_dir) self.output_edit.setPlaceholderText("请选择报告输出目录") output_layout.addWidget(self.output_edit, 3) self.output_browse_btn = QPushButton("浏览...") self.output_browse_btn.clicked.connect(self.browse_output) output_layout.addWidget(self.output_browse_btn) model_layout.addLayout(output_layout) main_layout.addWidget(model_group) # 控制按钮区域 control_layout = QHBoxLayout() control_layout.setSpacing(10) self.start_btn = QPushButton("开始分析") self.start_btn.setStyleSheet("background-color: #2196F3;") self.start_btn.clicked.connect(self.start_analysis) control_layout.addWidget(self.start_btn) self.stop_btn = QPushButton("停止分析") self.stop_btn.setStyleSheet("background-color: #f44336;") self.stop_btn.clicked.connect(self.stop_analysis) self.stop_btn.setEnabled(False) control_layout.addWidget(self.stop_btn) self.clear_btn = QPushButton("清空") self.clear_btn.clicked.connect(self.clear_all) control_layout.addWidget(self.clear_btn) main_layout.addLayout(control_layout) # 进度条 self.progress_bar = QProgressBar() self.progress_bar.setValue(0) self.progress_bar.setFormat("就绪") self.progress_bar.setMinimumHeight(25) main_layout.addWidget(self.progress_bar) # 日志输出区域 log_group = QGroupBox("分析日志") log_layout = QVBoxLayout(log_group) self.log_text = QTextEdit() self.log_text.setReadOnly(True) log_layout.addWidget(self.log_text) main_layout.addWidget(log_group, 1) # 给日志区域更多空间 # 状态区域 status_layout = QHBoxLayout() self.status_label = QLabel("状态: 就绪") status_layout.addWidget(self.status_label, 1) self.file_count_label = QLabel("已选择0个音频文件") status_layout.addWidget(self.file_count_label) main_layout.addLayout(status_layout) # 初始化分析线程 self.analysis_thread = None def browse_audio(self): """浏览音频文件或文件夹""" options = QFileDialog.Options() files, _ = QFileDialog.getOpenFileNames( self, "选择音频文件", "", "音频文件 (*.mp3 *.wav *.amr *.ogg *.flac *.m4a);;所有文件 (*)", options=options ) if files: self.audio_files = files self.audio_path_edit.setText("; ".join(files)) self.file_count_label.setText(f"已选择{len(files)}个音频文件") 
self.log_text.append(f"已选择{len(files)}个音频文件") def browse_keyword(self): """浏览关键词文件""" options = QFileDialog.Options() file, _ = QFileDialog.getOpenFileName( self, "选择关键词文件", "", "Excel文件 (*.xlsx *.xls);;所有文件 (*)", options=options ) if file: self.keyword_file = file self.keyword_path_edit.setText(file) self.log_text.append(f"已选择关键词文件: {file}") def browse_output(self): """浏览输出目录""" options = QFileDialog.Options() directory = QFileDialog.getExistingDirectory( self, "选择输出目录", self.output_dir, options=options ) if directory: self.output_dir = directory self.output_edit.setText(directory) self.log_text.append(f"输出目录设置为: {directory}") def start_analysis(self): """开始分析""" if not self.audio_files: self.show_warning("请先选择音频文件") return if not self.keyword_file: self.show_warning("请先选择关键词文件") return if not os.path.exists(self.keyword_file): self.show_warning("关键词文件不存在,请重新选择") return # 检查模型路径 model_paths = [ self.whisper_edit.text(), self.pyannote_edit.text(), self.emotion_edit.text() ] for path in model_paths: if not os.path.exists(path): self.show_warning(f"模型路径不存在: {path}") return # 更新模型路径 self.whisper_model_path = self.whisper_edit.text() self.pyannote_model_path = self.pyannote_edit.text() self.emotion_model_path = self.emotion_edit.text() self.output_dir = self.output_edit.text() # 创建输出目录 os.makedirs(self.output_dir, exist_ok=True) self.log_text.append("开始分析...") self.start_btn.setEnabled(False) self.stop_btn.setEnabled(True) self.status_label.setText("状态: 分析中...") self.progress_bar.setFormat("分析中... 0%") self.progress_bar.setValue(0) # 创建并启动分析线程 self.analysis_thread = AnalysisThread( self.audio_files, self.keyword_file, self.whisper_model_path, self.pyannote_model_path, self.emotion_model_path ) self.analysis_thread.progress.connect(self.update_progress) self.analysis_thread.message.connect(self.log_text.append) self.analysis_thread.analysis_complete.connect(self.on_analysis_complete) self.analysis_thread.error.connect(self.on_analysis_error) self.analysis_thread.finished.connect(self.on_analysis_finished) self.analysis_thread.start() def update_progress(self, value): """更新进度条""" self.progress_bar.setValue(value) self.progress_bar.setFormat(f"分析中... 
{value}%") def stop_analysis(self): """停止分析""" if self.analysis_thread and self.analysis_thread.isRunning(): self.analysis_thread.stop() self.log_text.append("正在停止分析...") self.stop_btn.setEnabled(False) def clear_all(self): """清空所有内容""" self.audio_files = [] self.keyword_file = "" self.audio_path_edit.clear() self.keyword_path_edit.clear() self.log_text.clear() self.progress_bar.setValue(0) self.progress_bar.setFormat("就绪") self.status_label.setText("状态: 就绪") self.file_count_label.setText("已选择0个音频文件") self.log_text.append("已清空所有内容") def show_warning(self, message): """显示警告消息""" QMessageBox.warning(self, "警告", message) self.log_text.append(f"警告: {message}") def on_analysis_complete(self, result): """分析完成处理""" try: self.log_text.append("正在生成报告...") if not result.get("results"): self.log_text.append("警告: 没有生成任何分析结果") return # 生成Excel报告 excel_path = os.path.join(self.output_dir, "质检分析报告.xlsx") self.generate_excel_report(result, excel_path) # 生成Word报告 word_path = os.path.join(self.output_dir, "质检分析报告.docx") self.generate_word_report(result, word_path) self.log_text.append(f"分析报告已保存至: {excel_path}") self.log_text.append(f"可视化报告已保存至: {word_path}") self.log_text.append("分析完成!") self.status_label.setText(f"状态: 分析完成!报告保存至: {self.output_dir}") self.progress_bar.setFormat("分析完成!") # 显示完成消息 QMessageBox.information( self, "分析完成", f"分析完成!报告已保存至:\n{excel_path}\n{word_path}" ) except Exception as e: import traceback error_msg = f"生成报告时出错: {str(e)}\n{traceback.format_exc()}" self.log_text.append(error_msg) logger.error(error_msg) def on_analysis_error(self, message): """分析错误处理""" self.log_text.append(f"错误: {message}") self.status_label.setText("状态: 发生错误") self.progress_bar.setFormat("发生错误") QMessageBox.critical(self, "分析错误", message) def on_analysis_finished(self): """分析线程结束处理""" self.start_btn.setEnabled(True) self.stop_btn.setEnabled(False) def generate_excel_report(self, result, output_path): """生成Excel报告(增强健壮性)""" try: # 从结果中提取数据 data = [] for res in result['results']: data.append({ "文件名": res['file_name'], "音频时长(秒)": res['duration'], "开场白检查": "通过" if res['opening_check'] else "未通过", "结束语检查": "通过" if res['closing_check'] else "未通过", "服务禁语检查": "通过" if not res['forbidden_check'] else "未通过", "客服情感": res['agent_emotion']['label'], "客服情感得分": res['agent_emotion']['score'], "客户情感": res['customer_emotion']['label'], "客户情感得分": res['customer_emotion']['score'], "语速(字/分)": res['speech_rate'], "平均音量(dB)": res['volume_mean'], "音量标准差": res['volume_std'], "问题解决率": "是" if res['resolution_rate'] else "否" }) # 创建DataFrame并保存 df = pd.DataFrame(data) # 尝试使用openpyxl引擎(更稳定) try: df.to_excel(output_path, index=False, engine='openpyxl') except ImportError: df.to_excel(output_path, index=False) # 添加汇总统计 try: with pd.ExcelWriter(output_path, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer: summary_data = { "统计项": ["总文件数", "开场白通过率", "结束语通过率", "服务禁语通过率", "问题解决率"], "数值": [ len(result['results']), df['开场白检查'].value_counts().get('通过', 0) / len(df), df['结束语检查'].value_counts().get('通过', 0) / len(df), df['服务禁语检查'].value_counts().get('通过', 0) / len(df), df['问题解决率'].value_counts().get('是', 0) / len(df) ] } summary_df = pd.DataFrame(summary_data) summary_df.to_excel(writer, sheet_name='汇总统计', index=False) except Exception as e: self.log_text.append(f"添加汇总统计时出错: {str(e)}") except Exception as e: raise Exception(f"生成Excel报告失败: {str(e)}") def generate_word_report(self, result, output_path): """生成Word报告(增强健壮性)""" try: doc = Document() # 添加标题 doc.add_heading('外呼电话录音质检分析报告', 0) # 添加基本信息 doc.add_heading('分析概况', level=1) 
doc.add_paragraph(f"分析时间: {time.strftime('%Y-%m-%d %H:%M:%S')}") doc.add_paragraph(f"分析文件数量: {len(result['results'])}") doc.add_paragraph(f"关键词文件: {os.path.basename(self.keyword_file)}") # 添加汇总统计 doc.add_heading('汇总统计', level=1) # 创建汇总表格 table = doc.add_table(rows=5, cols=2) table.style = 'Table Grid' # 表头 hdr_cells = table.rows[0].cells hdr_cells[0].text = '统计项' hdr_cells[1].text = '数值' # 计算统计数据 df = pd.DataFrame(result['results']) pass_rates = { "开场白通过率": df['opening_check'].mean() if not df.empty else 0, "结束语通过率": df['closing_check'].mean() if not df.empty else 0, "服务禁语通过率": (1 - df['forbidden_check']).mean() if not df.empty else 0, "问题解决率": df['resolution_rate'].mean() if not df.empty else 0 } # 填充表格 rows = [ ("总文件数", len(result['results'])), ("开场白通过率", f"{pass_rates['开场白通过率']:.2%}"), ("结束语通过率", f"{pass_rates['结束语通过率']:.2%}"), ("服务禁语通过率", f"{pass_rates['服务禁语通过率']:.2%}"), ("问题解决率", f"{pass_rates['问题解决率']:.2%}") ] for i, row_data in enumerate(rows): if i < len(table.rows): row_cells = table.rows[i].cells row_cells[0].text = row_data[0] row_cells[1].text = str(row_data[1]) # 添加情感分析图表 if result['results']: doc.add_heading('情感分析', level=1) # 客服情感分布 agent_emotions = [res['agent_emotion']['label'] for res in result['results']] agent_emotion_counts = pd.Series(agent_emotions).value_counts() if not agent_emotion_counts.empty: fig, ax = plt.subplots(figsize=(6, 4)) agent_emotion_counts.plot.pie(autopct='%1.1f%%', ax=ax) ax.set_title('客服情感分布') ax.set_ylabel('') # 移除默认的ylabel plt.tight_layout() # 保存图表到临时文件 chart_path = os.path.join(self.output_dir, "agent_emotion_chart.png") plt.savefig(chart_path, dpi=100, bbox_inches='tight') plt.close() doc.add_picture(chart_path, width=Inches(4)) doc.add_paragraph('图1: 客服情感分布') # 客户情感分布 customer_emotions = [res['customer_emotion']['label'] for res in result['results']] customer_emotion_counts = pd.Series(customer_emotions).value_counts() if not customer_emotion_counts.empty: fig, ax = plt.subplots(figsize=(6, 4)) customer_emotion_counts.plot.pie(autopct='%1.1f%%', ax=ax) ax.set_title('客户情感分布') ax.set_ylabel('') # 移除默认的ylabel plt.tight_layout() chart_path = os.path.join(self.output_dir, "customer_emotion_chart.png") plt.savefig(chart_path, dpi=100, bbox_inches='tight') plt.close() doc.add_picture(chart_path, width=Inches(4)) doc.add_paragraph('图2: 客户情感分布') # 添加详细分析结果 doc.add_heading('详细分析结果', level=1) # 创建详细表格 table = doc.add_table(rows=1, cols=6) table.style = 'Table Grid' # 表头 hdr_cells = table.rows[0].cells headers = ['文件名', '开场白', '结束语', '禁语', '客服情感', '问题解决'] for i, header in enumerate(headers): hdr_cells[i].text = header # 填充数据 for res in result['results']: row_cells = table.add_row().cells row_cells[0].text = res['file_name'] row_cells[1].text = "✓" if res['opening_check'] else "✗" row_cells[2].text = "✓" if res['closing_check'] else "✗" row_cells[3].text = "✗" if res['forbidden_check'] else "✓" row_cells[4].text = res['agent_emotion']['label'] row_cells[5].text = "✓" if res['resolution_rate'] else "✗" # 保存文档 doc.save(output_path) except Exception as e: raise Exception(f"生成Word报告失败: {str(e)}") if __name__ == "__main__": # 检查是否安装了torch try: import torch except ImportError: print("警告: PyTorch 未安装,情感分析可能无法使用GPU加速") app = QApplication(sys.argv) # 设置应用样式 app.setStyle("Fusion") window = MainWindow() window.show() sys.exit(app.exec_())

```python
from BaseNN import nn
import numpy as np
import torch

# 初始化模型
model = nn('cls')

# 正确加载模型(三步法)
# 1. 加载权重字典
state_dict = torch.load('model.pth', map_location='cpu')  # 确保CPU兼容性[^1]

# 2. 重建模型结构(必须与训练时一致)
# 示例结构,需替换为您的实际结构
model.add('linear', input_size=100, output_size=50, activation='relu')
model.add('linear', input_size=50, output_size=10, activation='softmax')

# 3. 将权重注入模型
model.model.load_state_dict(state_dict)  # 关键步骤[^2]

# 设置推理设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.model = model.model.to(device)  # 转移模型到设备[^3]
model.model.eval()  # 设置为评估模式

# 执行推理
input_data = '长'
with torch.no_grad():  # 禁用梯度计算
    result = model.inference(data=input_data)

# 处理输出
output = result[0].detach().cpu().numpy()  # 确保数据在CPU上
print("output: ", output)
index = np.argmax(output)
w = model.ix2word[index]  # 确保ix2word已定义
print("word:", w)
```

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[14], line 14
     10 state_dict = torch.load('model.pth', map_location='cpu')  # 确保CPU兼容性[^1]
     12 # 2. 重建模型结构(必须与训练时一致)
     13 # 示例结构,需替换为您的实际结构
---> 14 model.add('linear', input_size=100, output_size=50, activation='relu')
     15 model.add('linear', input_size=50, output_size=10, activation='softmax')
     17 # 3. 将权重注入模型

File /opt/conda/envs/mmedu/lib/python3.8/site-packages/BaseNN/BaseNN.py:308, in nn.add(self, layer, activation, optimizer, **kw)
    306 if layer == 'linear':
    307     self.model.add_module('reshape', Reshape(self.batchsize))
--> 308     self.model.add_module('linear' + str(self.layers_num), torch.nn.Linear(kw['size'][0], kw['size'][1]))
    309 self.last_channel = kw['size'][1]
    310 print("增加全连接层,输入维度:{},输出维度:{}。".format(kw['size'][0], kw['size'][1]))

KeyError: 'size'
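从回溯里贴出的 BaseNN 源码可以看到,`add('linear', ...)` 内部是用 `kw['size'][0]`、`kw['size'][1]` 读取全连接层的输入/输出维度,所以传 `input_size`/`output_size` 时关键字字典里没有 `size` 这个键,直接抛出 `KeyError: 'size'`。把层参数改成 `size=(输入维度, 输出维度)` 即可。下面是按这个接口改写的示意代码(100→50→10 只是沿用上面的示例结构,层数、维度和激活函数都要换成训练时的真实配置):

```python
from BaseNN import nn
import numpy as np
import torch

model = nn('cls')

# 1. 加载权重字典(map_location='cpu' 保证无 GPU 时也能加载)
state_dict = torch.load('model.pth', map_location='cpu')

# 2. 重建模型结构:BaseNN 的 linear 层用 size=(输入维度, 输出维度) 传参
model.add('linear', size=(100, 50), activation='relu')
model.add('linear', size=(50, 10), activation='softmax')

# 3. 注入权重并切换到推理状态
model.model.load_state_dict(state_dict)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.model = model.model.to(device)
model.model.eval()

# 推理
with torch.no_grad():
    result = model.inference(data='长')

output = result[0]
if isinstance(output, torch.Tensor):  # inference 返回值可能是张量也可能是 numpy 数组
    output = output.detach().cpu().numpy()
index = int(np.argmax(output))
print("预测索引:", index)
```

如果注入权重时又报 `size mismatch`,说明重建的结构和训练时不一致,可以先打印 `state_dict` 中各权重的形状(例如 `for k, v in state_dict.items(): print(k, v.shape)`)来反推每一层的真实维度;拿到索引后,若训练时保存过 `ix2word` 之类的映射,再用它把索引转换回文字即可。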

class BayesianNN(nn.Module): def __init__(self, input_dim=8, hidden_dim=50, output_dim=1): super().__init__() self.fc1 = BayesianLinear(input_dim, hidden_dim) self.fc2 = BayesianLinear(hidden_dim, hidden_dim) self.fc3 = BayesianLinear(hidden_dim, output_dim) self.relu = nn.ReLU() def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) return x def kl_loss(self): return self.fc1.kl_loss() + self.fc2.kl_loss() + self.fc3.kl_loss() def observation_noise(self): # 假设我们只使用最后一层的观测噪声 return self.fc3.observation_noise() # 初始化模型 model = BayesianNN() device = torch.device("cuda") model.to(device) optimizer = optim.Adam(model.parameters(), lr=0.01) def train(model, train_loader, optimizer, epochs=500, kl_weight=0.1): device = next(model.parameters()).device model.train() # 确保模型处于训练模式 # 训练循环 for epoch in range(epochs): total_nll = 0.0 total_kl = 0.0 total_loss = 0.0 num_batches = 0 # 迭代每个批次 for batch_idx, (x_batch, y_batch) in enumerate(train_loader): # 将数据移至模型所在的设备(GPU/CPU) x_batch = x_batch.to(device) y_batch = y_batch.to(device) # 梯度清零 optimizer.zero_grad() # 前向传播(采样) outputs = model(x_batch) # 计算观测噪声 - 确保是标量或可广播的张量 sigma = model.observation_noise() # 计算负对数似然损失 # 修正:使用正确的维度计算 squared_diff = (outputs - y_batch).pow(2) # 确保sigma有正确的维度用于广播 if sigma.dim() == 0: # 标量噪声 nll_loss = (0.5 * squared_diff / sigma.pow(2)).mean() + sigma.log() else: # 向量噪声 - 需要维度匹配 # 假设sigma形状为(batch_size, 1)或(batch_size, output_dim) sigma = sigma.view_as(outputs) # 确保形状匹配 nll_loss = (0.5 * squared_diff / sigma.pow(2)).mean() + sigma.log().mean() # 计算KL散度 kl_loss = model.kl_loss() # 总损失 = NLL + KL正则项 batch_loss = nll_loss + kl_weight * kl_loss # 反向传播 batch_loss.backward() # 梯度裁剪(防止梯度爆炸) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() # 累计损失用于日志 total_nll += nll_loss.item() total_kl += kl_loss.item() total_loss += batch_loss.item() num_batches += 1 # 计算平均损失 avg_nll = total_nll / num_batches avg_kl = total_kl / num_batches avg_loss = total_loss / num_batches # 每50个epoch打印一次 if epoch % 50 == 0: print(f"Epoch {epoch}: Avg NLL={avg_nll:.4f}, Avg KL={avg_kl:.4f}, Avg Total={avg_loss:.4f}") # 训练完成后保存模型 torch.save(model.state_dict(), 'bayesian_model.pth') print("Training completed. 
Model saved.") def visualize_predictions(model, data_loader, title="Predictions vs Ground Truth", num_samples=1000): model.eval() # 设置为评估模式 # 收集所有预测结果 all_preds = [] all_targets = [] all_inputs = [] with torch.no_grad(): for x_batch, y_batch in data_loader: x_batch = x_batch.to(model.device) y_batch = y_batch.to(model.device) # 获取预测结果 preds = model(x_batch) # 收集数据 all_preds.append(preds.cpu().numpy()) all_targets.append(y_batch.cpu().numpy()) all_inputs.append(x_batch.cpu().numpy()) # 如果样本足够,提前停止 if len(all_preds) * data_loader.batch_size >= num_samples: break # 合并数据 preds = np.concatenate(all_preds, axis=0) targets = np.concatenate(all_targets, axis=0) inputs = np.concatenate(all_inputs, axis=0) # 如果数据是多维的,只取第一个特征用于可视化 if inputs.ndim > 2: inputs = inputs[:, 0] preds = preds[:, 0] targets = targets[:, 0] # 创建图形 plt.figure(figsize=(15, 8)) # 绘制真实值 plt.plot(targets[:num_samples], 'b-', label='Ground Truth', alpha=0.7) # 绘制预测值 plt.plot(preds[:num_samples], 'r--', label='Predictions', alpha=0.8) # 添加不确定区间(如果模型支持) if hasattr(model, 'predict_with_uncertainty'): mean, std = model.predict_with_uncertainty(inputs[:num_samples]) plt.fill_between( np.arange(len(mean)), mean - 2 * std, mean + 2 * std, color='orange', alpha=0.3, label='95% Confidence Interval' ) plt.title(title) plt.xlabel('Sample Index') plt.ylabel('Target Value') plt.legend() plt.grid(True, linestyle='--', alpha=0.6) plt.tight_layout() plt.show() # 返回数据用于进一步分析 return inputs[:num_samples], targets[:num_samples], preds[:num_samples] # 使用示例 if __name__ == "__main__": # 假设模型和优化器已经定义 # model = YourModel(...) # optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # 训练模型 train(model, train_loader, optimizer, epochs=100, kl_weight=0.1) # 可视化训练集预测 print("\nVisualizing training predictions:") train_inputs, train_targets, train_preds = visualize_predictions( model, train_loader, title="Training Set Predictions vs Ground Truth" ) # 可视化验证集预测 print("\nVisualizing validation predictions:") val_inputs, val_targets, val_preds = visualize_predictions( model, val_loader, title="Validation Set Predictions vs Ground Truth" ) # 可视化测试集预测 print("\nVisualizing test predictions:") test_inputs, test_targets, test_preds = visualize_predictions( model, test_loader, title="Test Set Predictions vs Ground Truth" ) # 可选:计算评估指标 def calculate_metrics(targets, preds, set_name=""): mae = np.mean(np.abs(targets - preds)) rmse = np.sqrt(np.mean((targets - preds) ** 2)) print(f"\n{set_name} Metrics:") print(f" MAE: {mae:.4f}") print(f" RMSE: {rmse:.4f}") return mae, rmse train_mae, train_rmse = calculate_metrics(train_targets, train_preds, "Training Set") val_mae, val_rmse = calculate_metrics(val_targets, val_preds, "Validation Set") test_mae, test_rmse = calculate_metrics(test_targets, test_preds, "Test Set") Traceback (most recent call last): File "F:\PythonProject\学习python.py", line 959, in <module> train(model, train_loader, optimizer, epochs=100, kl_weight=0.1) File "F:\PythonProject\学习python.py", line 846, in train sigma = sigma.view_as(outputs) # 确保形状匹配 ^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: shape '[798, 1]' is invalid for input of size 1 有问题,修改代码
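报错发生在 `sigma = sigma.view_as(outputs)` 这一行:`model.observation_noise()` 返回的是只含 1 个元素的张量(例如形状为 `[1]`),它的 `dim()` 并不等于 0,于是代码走进了"向量噪声"分支,而 1 个元素的张量无法 `view` 成 `[798, 1]`。`view_as` 要求元素个数完全一致,这里其实只需要依靠广播。下面是一个自包含的修改示意(假设 `observation_noise()` 返回的是正的标准差;若它返回的是对数方差等其他参数化形式,需要先做相应变换):

```python
import torch

def gaussian_nll(outputs: torch.Tensor, targets: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """高斯负对数似然:sigma 可以是 0 维标量、形状 [1],或可广播到 outputs 形状的张量。"""
    squared_diff = (outputs - targets).pow(2)
    # 直接依靠广播,不再使用 view_as;sigma 为标量时 .mean() 与取其本身等价
    return (0.5 * squared_diff / sigma.pow(2)).mean() + sigma.log().mean()

# 与报错场景同形状的自检:798 个样本、单元素噪声
outputs = torch.randn(798, 1)
targets = torch.randn(798, 1)
sigma = torch.tensor([0.5])  # 形状 [1],原代码在这里触发 view_as 报错
print(gaussian_nll(outputs, targets, sigma))  # 正常输出一个标量损失
```

在 `train` 函数里,把原来 `if sigma.dim() == 0 ... else: sigma = sigma.view_as(outputs) ...` 的整个分支替换成 `nll_loss = gaussian_nll(outputs, y_batch, sigma)` 即可,KL 项和总损失的写法保持不变。另外两处与设备相关的小问题也建议一并修正:`visualize_predictions` 里用到了 `model.device`,而 `nn.Module` 并没有这个默认属性,可改成 `device = next(model.parameters()).device`;`device = torch.device("cuda")` 也建议加上 `torch.cuda.is_available()` 判断,避免在无 GPU 的机器上直接报错。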
