暂时可以生成hdf5数据

This commit is contained in:
2026-02-17 22:20:25 +08:00
parent ba006e14c4
commit b701d939c2
3 changed files with 443 additions and 0 deletions

31
models/text_encoder.py Normal file
View File

@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
class DistilBERTTextEncoder(nn.Module):
    """Text encoder that returns the first-token ([CLS]) embedding of DistilBERT.

    The wrapped model is loaded with ``DistilBertModel.from_pretrained``. When
    ``freeze`` is true, every encoder parameter has ``requires_grad`` disabled
    and the encoder is kept in evaluation mode on each forward pass.

    Attributes:
        encoder: The pretrained ``DistilBertModel`` instance.
        output_dim: The feature width requested by the caller. It is stored
            as given; no projection is applied, so the real feature width is
            whatever the encoder produces.
        freeze: Whether the encoder is frozen (see above).
    """

    def __init__(
        self,
        model_name: str = 'distilbert-base-uncased',
        output_dim: int = 768,
        freeze: bool = True,
    ) -> None:
        """Load the pretrained encoder and optionally freeze it.

        Args:
            model_name: Name or path passed to ``DistilBertModel.from_pretrained``.
            output_dim: Feature width the caller expects from ``forward``.
            freeze: If true, disable gradients for the encoder and put it in
                evaluation mode.

        Raises:
            ImportError: If the ``transformers`` package is not installed.
        """
        super().__init__()
        # Imported lazily so that importing this module does not require
        # transformers unless the encoder is actually constructed.
        try:
            from transformers import DistilBertModel
        except ImportError as exc:
            raise ImportError(
                'transformers is required for DistilBERT text encoding. '
                'Install it with: pip install transformers'
            ) from exc

        self.encoder = DistilBertModel.from_pretrained(model_name)
        self.output_dim = output_dim
        self.freeze = freeze

        # forward() returns the encoder's hidden state unchanged, so a
        # mismatched output_dim would silently advertise the wrong width.
        # Surface it without changing behaviour.
        encoder_config = getattr(self.encoder, 'config', None)
        hidden_size = getattr(encoder_config, 'hidden_size', None)
        if hidden_size is not None and hidden_size != output_dim:
            import warnings

            warnings.warn(
                f'output_dim={output_dim} does not match the hidden size '
                f'({hidden_size}) of {model_name!r}; forward() returns '
                f'{hidden_size}-dimensional features.',
                stacklevel=2,
            )

        if self.freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.eval()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Encode token ids and return the first-token embedding per sequence.

        Args:
            input_ids: Token ids, passed to the encoder as ``input_ids``.
            attention_mask: Optional mask passed to the encoder as
                ``attention_mask``. ``None`` is forwarded as-is; the Hugging
                Face model is expected to treat that as "attend to every
                token" (see the DistilBertModel documentation).

        Returns:
            ``last_hidden_state[:, 0, :]`` of the encoder output, i.e. the
            embedding at sequence position 0 for every item in the batch.
        """
        if self.freeze:
            # A parent module's .train() call flips the encoder back into
            # training mode; re-assert eval mode so a frozen encoder never
            # runs with dropout enabled.
            self.encoder.eval()
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # DistilBERT has no pooled output; use [CLS] token embedding
        cls_feature = outputs.last_hidden_state[:, 0, :]
        return cls_feature