32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class DistilBERTTextEncoder(nn.Module):
    """Encode tokenized text into one fixed-size vector per sequence with DistilBERT.

    The sentence feature is the last-layer hidden state of the first token
    ([CLS]), passed through a projection so the returned feature has
    ``output_dim`` channels.

    Args:
        model_name: Model id or local path passed to
            ``DistilBertModel.from_pretrained``.
        output_dim: Width of the returned feature. When it equals the
            encoder's hidden size, the projection is an identity and the
            output is the raw [CLS] embedding. Otherwise a trainable,
            randomly initialised ``nn.Linear`` maps hidden size ->
            ``output_dim``. That layer is NOT frozen by ``freeze``.
        freeze: If True, the DistilBERT weights get ``requires_grad=False``
            and the encoder is kept in eval mode, even when the parent
            module is switched to training mode.

    Raises:
        ImportError: If the ``transformers`` package is not installed.
    """

    def __init__(self, model_name='distilbert-base-uncased', output_dim=768, freeze=True):
        super().__init__()
        try:
            # Imported here so transformers is only required when this
            # encoder is actually constructed.
            from transformers import DistilBertModel
        except ImportError as exc:
            raise ImportError(
                'transformers is required for DistilBERT text encoding. '
                'Install it with: pip install transformers'
            ) from exc

        self.encoder = DistilBertModel.from_pretrained(model_name)
        self.output_dim = output_dim
        self.freeze = freeze

        # Make ``output_dim`` effective: without a projection the returned
        # width would always be the encoder's hidden size.
        hidden_size = self.encoder.config.hidden_size
        if output_dim == hidden_size:
            self.projection = nn.Identity()
        else:
            self.projection = nn.Linear(hidden_size, output_dim)

        if self.freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.eval()

    def train(self, mode=True):
        """Set training mode, keeping a frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        put the frozen DistilBERT back into train mode (e.g. enabling dropout).

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        if self.freeze:
            self.encoder.eval()
        return self

    def forward(self, input_ids, attention_mask=None):
        """Return one feature vector per sequence.

        Args:
            input_ids: Token ids for the batch (presumably ``(batch, seq_len)``
                from a DistilBERT tokenizer -- TODO confirm with callers).
            attention_mask: Optional padding mask forwarded to DistilBERT;
                ``None`` applies no padding mask.

        Returns:
            Tensor whose last dimension is ``output_dim``: the (projected)
            first-token embedding taken from ``last_hidden_state[:, 0, :]``.
        """
        if self.freeze:
            # Defensive: also covers callers that called ``self.encoder.train()``
            # directly rather than going through ``self.train()``.
            self.encoder.eval()
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # DistilBERT has no pooled output; use the [CLS] token embedding.
        cls_feature = outputs.last_hidden_state[:, 0, :]
        return self.projection(cls_feature)
|