:py:mod:`medcat.utils.relation_extraction.models`
=================================================

.. py:module:: medcat.utils.relation_extraction.models


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   medcat.utils.relation_extraction.models.BertModel_RelationExtraction


.. py:class:: BertModel_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)

   Bases: :py:obj:`torch.nn.Module`

   BERT model class for RelCAT.

   .. py:attribute:: name
      :value: 'bertmodel_relcat'

   .. py:attribute:: log

   .. py:method:: __init__(pretrained_model_name_or_path, relcat_config, model_config)

      Class to hold the BERT model and its ``model_config``.

      :param pretrained_model_name_or_path: Path to load the model from. This can be a Hugging Face model name, e.g. "bert-base-uncased". If left empty, it is normally assumed that the model is loaded from 'model.dat' using the ``RelCAT.load()`` method, so if you are initializing or training a model from scratch, be sure to base it on an existing pretrained model.
      :type pretrained_model_name_or_path: str
      :param relcat_config: RelCAT config.
      :type relcat_config: ConfigRelCAT
      :param model_config: Hugging Face BERT config for the model.
      :type model_config: BertConfig

   .. py:method:: get_annotation_schema_tag(sequence_output, input_ids, special_tag)

      Gets the token sequences from ``sequence_output`` for the specific token tag ids in ``self.relcat_config.general.annotation_schema_tag_ids``.

      :param sequence_output: hidden states/embeddings for each token in the input text
      :type sequence_output: torch.Tensor
      :param input_ids: input token ids
      :type input_ids: torch.Tensor
      :param special_tag: special annotation token id pairs
      :type special_tag: List

      :Returns: **torch.Tensor** -- new seq_tags

   .. py:method:: output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)

      :param pooled_output: embedding of the CLS token
      :type pooled_output: torch.Tensor
      :param sequence_output: hidden states/embeddings for each token in the input text
      :type sequence_output: torch.Tensor
      :param input_ids: input token ids
      :type input_ids: torch.Tensor
      :param e1_e2_start: token positions of the annotation tags
      :type e1_e2_start: torch.Tensor

      :Returns: **torch.Tensor** -- classification probabilities for each token.

   .. py:method:: forward(input_ids = None, attention_mask = None, token_type_ids = None, position_ids = None, head_mask = None, encoder_hidden_states = None, encoder_attention_mask = None, Q = None, e1_e2_start = None, pooled_output = None)
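
``get_annotation_schema_tag`` and ``output2logits`` both reduce per-token embeddings to a fixed-size representation for relation classification. The sketch below illustrates two common ways of doing that in plain Python, with nested lists standing in for ``torch.Tensor``. It is a hypothetical illustration, not MedCAT's implementation: the helper names ``span_max_pool``, ``entity_start_features``, ``linear`` and ``softmax`` are invented here, and both the max-pooling over the tokens between a special-tag pair and the concatenation of the CLS embedding with the embeddings at the two ``e1_e2_start`` positions are assumptions about how these methods behave.

```python
import math
from typing import List, Sequence

# Hypothetical helpers: plain-list stand-ins for torch.Tensor operations.
Vector = List[float]


def span_max_pool(
    tokens: Sequence[Vector],
    input_ids: Sequence[int],
    start_id: int,
    end_id: int,
) -> Vector:
    """Max-pool the embeddings strictly between start_id and the next end_id.

    Assumption: roughly what get_annotation_schema_tag does for one
    (start, end) special-tag pair on a single sequence.
    """
    ids = list(input_ids)
    start = ids.index(start_id)            # raises ValueError if the tag is missing
    end = ids.index(end_id, start + 1)     # first closing tag after the opening one
    span = tokens[start + 1:end]
    if not span:
        raise ValueError("no tokens between the annotation tags")
    # Element-wise maximum over the hidden dimension.
    return [max(column) for column in zip(*span)]


def entity_start_features(
    pooled_output: Sequence[Vector],
    sequence_output: Sequence[Sequence[Vector]],
    e1_e2_start: Sequence[Sequence[int]],
) -> List[Vector]:
    """Concatenate the CLS embedding with the embeddings at both entity-start markers.

    Shapes: pooled_output is batch x hidden, sequence_output is
    batch x seq_len x hidden, e1_e2_start is batch x 2.
    """
    features = []
    for cls_vec, tokens, (e1_pos, e2_pos) in zip(pooled_output, sequence_output, e1_e2_start):
        features.append(list(cls_vec) + list(tokens[e1_pos]) + list(tokens[e2_pos]))
    return features


def linear(x: Vector, weights: Sequence[Vector], bias: Vector) -> Vector:
    """Dense layer: one logit per row of weights (num_labels x in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def softmax(logits: Vector) -> Vector:
    """Numerically stable softmax turning logits into class probabilities."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # Toy sequence, hidden size 2; ids 5 and 6 are assumed special-tag ids.
    input_ids = [101, 5, 7, 8, 6, 102]
    tokens = [[0.0, 0.0], [9.0, 9.0], [1.0, 4.0], [3.0, 2.0], [9.0, 9.0], [0.0, 0.0]]
    print(span_max_pool(tokens, input_ids, 5, 6))  # [3.0, 4.0]

    feats = entity_start_features([[1.0, 0.0]], [[[0.0, 0.0], [0.5, 0.5], [1.0, 2.0], [3.0, 4.0]]], [[1, 3]])
    print(feats)  # [[1.0, 0.0, 0.5, 0.5, 3.0, 4.0]]
    print(softmax(linear(feats[0], [[1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0]], [0.0, 0.0])))
```

Pooling over the whole tagged span uses every token of a multi-token entity, while reading only the marker positions keeps the feature size fixed at three hidden vectors regardless of entity length; which strategy a given RelCAT configuration uses should be checked against the MedCAT source.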