暂时可以生成hdf5数据
This commit is contained in:
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
31
models/text_encoder.py
Normal file
31
models/text_encoder.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DistilBERTTextEncoder(nn.Module):
|
||||
def __init__(self, model_name='distilbert-base-uncased', output_dim=768, freeze=True):
|
||||
super().__init__()
|
||||
try:
|
||||
from transformers import DistilBertModel
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
'transformers is required for DistilBERT text encoding. '
|
||||
'Install it with: pip install transformers'
|
||||
) from exc
|
||||
|
||||
self.encoder = DistilBertModel.from_pretrained(model_name)
|
||||
self.output_dim = output_dim
|
||||
self.freeze = freeze
|
||||
|
||||
if self.freeze:
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.encoder.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
if self.freeze:
|
||||
self.encoder.eval()
|
||||
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
||||
# DistilBERT has no pooled output; use [CLS] token embedding
|
||||
cls_feature = outputs.last_hidden_state[:, 0, :]
|
||||
return cls_feature
|
||||
Reference in New Issue
Block a user