Trying out SentenceBERT
Environment
Python 3.9.4
pip install transformers fugashi ipadic torch
transformers==4.16.2
torch==1.10.2
numpy==1.22.2
TODO: I want to understand torch.Tensor shape conversions better
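On that TODO: the mean pooling in the code below hinges on a few torch.Tensor shape conversions (unsqueeze, expand, clamp). Here is a minimal sketch of just those operations on dummy tensors; all names and sizes in this sketch are made up for illustration and are not part of the script below.
code:pooling_shapes.py
import torch

# dummy "last hidden state" and attention mask: batch_size=2, token_length=4, hidden_size=3
token_embeddings = torch.randn(2, 4, 3)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# (2, 4) -> (2, 4, 1) -> (2, 4, 3): broadcast the mask across the hidden dimension
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

# sum only the non-padding token vectors, then divide by the number of real tokens
summed = torch.sum(token_embeddings * mask, dim=1)  # (2, 3)
counts = torch.clamp(mask.sum(dim=1), min=1e-9)     # (2, 3)
mean_pooled = summed / counts
print(mean_pooled.shape)  # torch.Size([2, 3])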
code:example.py
from __future__ import annotations
from collections.abc import Sequence
import torch
from transformers import BertJapaneseTokenizer, BertModel
from transformers.modeling_outputs import (
    BaseModelOutputWithPoolingAndCrossAttentions,
)
class SentenceBertJapanese:
    def __init__(self, model_name_or_path: str, device: str | None = None):
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(
            model_name_or_path
        )
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(device)

    def _mean_pooling(
        self,
        model_output: BaseModelOutputWithPoolingAndCrossAttentions,
        attention_mask: torch.Tensor,
    ):
        # attention_mask size: (batch_size, token_length)
        # model_output[0] is last_hidden_state, size: (batch_size, token_length, 768)
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1)  # (batch_size, token_length, 1)
            .expand(token_embeddings.size())  # (batch_size, token_length, 768)
            .float()
        )
        return torch.sum(
            token_embeddings * input_mask_expanded, dim=1
        ) / torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)

    @torch.no_grad()
    def encode(
        self, sentences: Sequence[str], batch_size: int = 8
    ) -> torch.Tensor:
        all_embeddings = []
        for batch_idx in range(0, len(sentences), batch_size):
            batch = sentences[batch_idx : batch_idx + batch_size]
            # returns a dict with three keys: input_ids, token_type_ids, attention_mask
            encoded_input = self.tokenizer.batch_encode_plus(
                batch, padding="longest", truncation=True, return_tensors="pt"
            ).to(self.device)
            model_output = self.model(**encoded_input)
            sentence_embeddings = self._mean_pooling(
                model_output, encoded_input["attention_mask"]
            )
            all_embeddings.extend(sentence_embeddings)
        return torch.stack(all_embeddings)


def calc_cosine_similarity(
    tensor1: torch.Tensor, tensor2: torch.Tensor
) -> float:
    return torch.cosine_similarity(tensor1, tensor2, dim=0).item()
if __name__ == "__main__":
    MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"
    model = SentenceBertJapanese(MODEL_NAME)
    # NOTE: sentences (the first list of Japanese example sentences) is not defined in this snippet
    sentence_embeddings = model.encode(sentences, batch_size=8)
    sentences2 = [
        "大学構内では喫煙禁止です。",  # Smoking is prohibited on campus.
        "大学でタバコを吸うのはダメです。",  # Smoking at the university is not allowed.
        "今日は学校でタバコを買った。",  # Today I bought cigarettes at school.
    ]
    sentence_embeddings2 = model.encode(sentences2, batch_size=8)
code:python
$ python -i example.py
>>> calc_cosine_similarity(sentence_embeddings[0], sentence_embeddings[1])
0.48751819133758545
>>> calc_cosine_similarity(sentence_embeddings[1], sentence_embeddings[0])
0.48751819133758545
# Good that similarity(sentences2[0], sentences2[1]) > similarity(sentences2[1], sentences2[2]) (only 0 and 1 are semantically similar)
>>> calc_cosine_similarity(sentence_embeddings2[0], sentence_embeddings2[1])
0.6689098477363586
>>> calc_cosine_similarity(sentence_embeddings2[0], sentence_embeddings2[2])
0.3702680468559265
>>> calc_cosine_similarity(sentence_embeddings2[1], sentence_embeddings2[2])
0.41787001490592957
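All pairwise similarities can also be computed in one shot by L2-normalizing the embeddings and taking a matrix product. A sketch under the assumption that sentence_embeddings2 from the session above is still in scope (the diagonal should come out as 1.0):
code:similarity_matrix.py
import torch.nn.functional as F

# sentence_embeddings2: the (3, 768) tensor returned by model.encode(sentences2)
normalized = F.normalize(sentence_embeddings2, dim=1)  # L2-normalize each embedding
similarity_matrix = normalized @ normalized.T  # (3, 3); entry [i, j] is the cosine similarity of sentences i and j
print(similarity_matrix)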