import torch
import torch.nn as nn


class DistilBERTTextEncoder(nn.Module):
    """Encode token ids into per-sequence features with a pretrained DistilBERT.

    The feature for each sequence is the final-layer hidden state at position 0
    (the [CLS] token), optionally linearly projected to ``output_dim``.

    Args:
        model_name: Model identifier or local path passed to
            ``DistilBertModel.from_pretrained``.
        output_dim: Size of the returned feature vector. If it equals the
            encoder's hidden size (``encoder.config.dim``) or is ``None``, the
            [CLS] hidden state is returned unchanged; otherwise a trainable
            ``nn.Linear(hidden_size, output_dim)`` projection is applied.
        freeze: If True, encoder weights get ``requires_grad=False`` and the
            encoder is kept in eval mode even when this module is switched to
            training mode.

    Raises:
        ImportError: If the ``transformers`` package is not installed.
    """

    def __init__(self, model_name: str = 'distilbert-base-uncased', output_dim=768,
                 freeze: bool = True):
        super().__init__()
        # Import lazily so this module can be imported without transformers.
        try:
            from transformers import DistilBertModel
        except ImportError as exc:
            raise ImportError(
                'transformers is required for DistilBERT text encoding. '
                'Install it with: pip install transformers'
            ) from exc

        self.encoder = DistilBertModel.from_pretrained(model_name)
        hidden_dim = self.encoder.config.dim
        if output_dim is None:
            output_dim = hidden_dim
        self.output_dim = output_dim
        self.freeze = freeze

        # output_dim used to be stored but never applied, so it could disagree
        # with the real feature width. Only add weights when the sizes differ:
        # nn.Identity has no parameters, so the default configuration keeps an
        # identical state_dict and identical outputs.
        if output_dim == hidden_dim:
            self.projection = nn.Identity()
        else:
            self.projection = nn.Linear(hidden_dim, output_dim)

        if self.freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.eval()

    def train(self, mode: bool = True) -> 'DistilBERTTextEncoder':
        """Set training mode, keeping a frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        re-enable dropout inside the frozen encoder.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        if self.freeze:
            self.encoder.eval()
        return self

    def forward(self, input_ids: torch.Tensor,
                attention_mask: 'torch.Tensor | None' = None) -> torch.Tensor:
        """Return the (optionally projected) [CLS] feature for each sequence.

        Args:
            input_ids: Token ids; assumed shape (batch, seq_len).
            attention_mask: Passed straight to the encoder. ``None`` lets the
                encoder apply its own default (attend to every position).

        Returns:
            Tensor of shape (batch, output_dim).
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # DistilBERT has no pooled output; use the [CLS] (position 0) embedding.
        cls_feature = outputs.last_hidden_state[:, 0, :]
        return self.projection(cls_feature)