medcat.utils.relation_extraction.models
Module Contents
Classes
BertModel class for RelCAT
- class medcat.utils.relation_extraction.models.BertModel_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)
Bases:
torch.nn.Module
BertModel class for RelCAT
- Parameters:
pretrained_model_name_or_path (str) –
relcat_config (medcat.config_rel_cat.ConfigRelCAT) –
model_config (transformers.models.bert.configuration_bert.BertConfig) –
- name = 'bertmodel_relcat'
- log
- __init__(pretrained_model_name_or_path, relcat_config, model_config)
Class to hold the BERT model + model_config
- Parameters:
pretrained_model_name_or_path (str) – path to load the model from; this can be a HF model name, e.g. “bert-base-uncased”. If left empty, the model is normally assumed to be loaded from ‘model.dat’ via the RelCAT.load() method, so if you are initializing/training a model from scratch, be sure to base it on some model.
relcat_config (ConfigRelCAT) – relcat config.
model_config (BertConfig) – HF bert config for model.
- get_annotation_schema_tag(sequence_output, input_ids, special_tag)
- Gets the token sequences from the sequence_output for the specific token
tag ids in self.relcat_config.general.annotation_schema_tag_ids.
- Parameters:
sequence_output (torch.Tensor) – hidden states/embeddings for each token in the input text
input_ids (torch.Tensor) – input token ids
special_tag (List) – special annotation token id pairs
- Returns:
torch.Tensor – new seq_tags
- Return type:
torch.Tensor
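Conceptually, the method scans input_ids for the annotation tag ids and gathers the hidden states at those positions. A minimal pure-Python sketch of that indexing idea (plain lists stand in for torch tensors; the tag ids and helper name below are made up for illustration, not the actual implementation):

```python
def gather_tag_states(sequence_output, input_ids, special_tag_ids):
    """Collect the hidden-state vectors at the positions of the given
    annotation tag ids -- a list-based stand-in for the torch logic."""
    positions = [i for i, tok in enumerate(input_ids) if tok in special_tag_ids]
    return [sequence_output[i] for i in positions]

# Toy example: 5 tokens with 2-dim "embeddings"; ids 30522/30523 are
# hypothetical annotation tag ids added to the tokenizer vocabulary.
hidden = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]
ids = [101, 30522, 2054, 30523, 102]
tags = gather_tag_states(hidden, ids, {30522, 30523})
# tags == [[0.3, 0.4], [0.7, 0.8]]
```

In the real method the gathering is done with torch tensor indexing over a batch, but the selection logic is the same.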
- output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)
- Parameters:
pooled_output (torch.Tensor) – embedding of the CLS token
sequence_output (torch.Tensor) – hidden states/embeddings for each token in the input text
input_ids (torch.Tensor) – input token ids.
e1_e2_start (torch.Tensor) – positions of the entity start annotation tokens
- Returns:
torch.Tensor – relation classification logits.
- Return type:
torch.Tensor
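The general pattern here is to build a feature vector from the pooled [CLS] embedding and the hidden states at the entity positions, then project it through a linear layer to per-class logits. A simplified pure-Python sketch of that projection, assuming a plain concatenate-then-linear scheme (the helper name, dimensions, and weights below are invented for illustration):

```python
def to_logits(pooled_output, entity_states, weights, bias):
    """Concatenate the [CLS] embedding with the entity hidden states
    and apply one linear layer -- a list-based stand-in for the torch code."""
    features = list(pooled_output)
    for state in entity_states:
        features.extend(state)
    # One output value per class: row . features + bias
    return [sum(w * f for w, f in zip(row, features)) + b
            for row, b in zip(weights, bias)]

# Toy numbers: 2-dim CLS vector plus two 2-dim entity vectors -> 6 features,
# projected to 2 relation classes.
cls_vec = [1.0, 0.0]
ent = [[0.0, 1.0], [1.0, 1.0]]
W = [[1, 0, 0, 0, 0, 0],   # picks feature 0
     [0, 0, 0, 1, 0, 0]]   # picks feature 3
b = [0.0, 0.0]
logits = to_logits(cls_vec, ent, W, b)
# logits == [1.0, 1.0]
```

The logits are unnormalized scores; a softmax (applied later, e.g. inside the loss) turns them into class probabilities.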
- forward(input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, Q=None, e1_e2_start=None, pooled_output=None)
- Parameters:
input_ids (Optional[torch.Tensor]) –
attention_mask (Optional[torch.Tensor]) –
token_type_ids (Optional[torch.Tensor]) –
position_ids (Any) –
head_mask (Any) –
encoder_hidden_states (Any) –
encoder_attention_mask (Any) –
Q (Any) –
e1_e2_start (Any) –
pooled_output (Any) –
- Return type:
Tuple[torch.Tensor, torch.Tensor]