:py:mod:`medcat.utils.relation_extraction.models`
=================================================

.. py:module:: medcat.utils.relation_extraction.models


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   medcat.utils.relation_extraction.models.BaseModelBluePrint_RelationExtraction
   medcat.utils.relation_extraction.models.BaseModel_RelationExtraction


.. py:class:: BaseModelBluePrint_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)

   Bases: :py:obj:`torch.nn.Module`

   Base class for the RelCAT models.

   .. py:attribute:: hf_model
      :type: transformers.PreTrainedModel

   .. py:attribute:: relcat_config
      :type: medcat.config_rel_cat.ConfigRelCAT

   .. py:attribute:: model_config
      :type: transformers.PretrainedConfig

   .. py:attribute:: drop_out
      :type: torch.nn.Dropout

   .. py:attribute:: fc1
      :type: torch.nn.Linear

   .. py:attribute:: fc2
      :type: torch.nn.Linear

   .. py:attribute:: fc3
      :type: torch.nn.Linear

   .. py:method:: __init__(pretrained_model_name_or_path, relcat_config, model_config)

      Hold the Hugging Face (HF) model together with its model config.

      :param pretrained_model_name_or_path: Path or HF model name to load the model from, e.g. "bert-base-uncased". If left empty, the model is normally expected to be loaded from 'model.dat' using the RelCAT.load() method, so when initializing or training a model from scratch, be sure to base it on an existing pretrained model.
      :type pretrained_model_name_or_path: str
      :param relcat_config: RelCAT config.
      :type relcat_config: ConfigRelCAT
      :param model_config: HF BERT config for the model.
      :type model_config: PretrainedConfig

   .. py:method:: forward(input_ids = None, attention_mask = None, token_type_ids = None, position_ids = None, head_mask = None, encoder_hidden_states = None, encoder_attention_mask = None, Q = None, e1_e2_start = None, pooled_output = None)

      Forward pass for the model.

      :param input_ids: Input token IDs. Defaults to None.
      :type input_ids: torch.Tensor
      :param attention_mask: Attention mask for the input IDs. Defaults to None.
      :type attention_mask: torch.Tensor
      :param token_type_ids: Token type IDs for the input IDs. Defaults to None.
      :type token_type_ids: torch.Tensor
      :param position_ids: The position IDs. Defaults to None.
      :type position_ids: Any
      :param head_mask: The head mask. Defaults to None.
      :type head_mask: Any
      :param encoder_hidden_states: Encoder hidden states. Defaults to None.
      :type encoder_hidden_states: Any
      :param encoder_attention_mask: Encoder attention mask. Defaults to None.
      :type encoder_attention_mask: Any
      :param Q: Q. Defaults to None.
      :type Q: Any
      :param e1_e2_start: Start and end indices for the entities in the input IDs. Defaults to None.
      :type e1_e2_start: Any
      :param pooled_output: The pooled output. Defaults to None.
      :type pooled_output: Any
      :Returns: **Optional[Tuple[torch.Tensor, torch.Tensor]]** -- logits for the relation classification task.

   .. py:method:: output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)

      Convert the output of the model to logits.

      :param pooled_output: Output of the pooled layer.
      :type pooled_output: torch.Tensor
      :param sequence_output: Output of the sequence layer.
      :type sequence_output: torch.Tensor
      :param input_ids: Input token IDs.
      :type input_ids: torch.Tensor
      :param e1_e2_start: Start and end indices for the entities in the input IDs.
      :type e1_e2_start: torch.Tensor
      :Returns: **logits** (*torch.Tensor*) -- logits for the relation classification task.


.. py:class:: BaseModel_RelationExtraction(relcat_config, model_config, pretrained_model_name_or_path)

   Bases: :py:obj:`BaseModelBluePrint_RelationExtraction`

   Base class for the RelCAT models.

   .. py:attribute:: name
      :value: 'basemodel_relcat'

   .. py:attribute:: log

   .. py:method:: __init__(relcat_config, model_config, pretrained_model_name_or_path)

      Hold the Hugging Face (HF) model together with its model config.

      :param pretrained_model_name_or_path: Path or HF model name to load the model from, e.g. "bert-base-uncased". If left empty, the model is normally expected to be loaded from 'model.dat' using the RelCAT.load() method, so when initializing or training a model from scratch, be sure to base it on an existing pretrained model.
      :type pretrained_model_name_or_path: str
      :param relcat_config: RelCAT config.
      :type relcat_config: ConfigRelCAT
      :param model_config: HF BERT config for the model.
      :type model_config: PretrainedConfig

   .. py:method:: _reinitialize_dense_and_frozen_layers(relcat_config)

      Reinitialize the dense layers of the model.

      :param relcat_config: RelCAT config.
      :type relcat_config: ConfigRelCAT

   .. py:method:: forward(input_ids = None, attention_mask = None, token_type_ids = None, position_ids = None, head_mask = None, encoder_hidden_states = None, encoder_attention_mask = None, Q = None, e1_e2_start = None, pooled_output = None)

      Forward pass for the model.

      :param input_ids: Input token IDs. Defaults to None.
      :type input_ids: torch.Tensor
      :param attention_mask: Attention mask for the input IDs. Defaults to None.
      :type attention_mask: torch.Tensor
      :param token_type_ids: Token type IDs for the input IDs. Defaults to None.
      :type token_type_ids: torch.Tensor
      :param position_ids: The position IDs. Defaults to None.
      :type position_ids: Any
      :param head_mask: The head mask. Defaults to None.
      :type head_mask: Any
      :param encoder_hidden_states: Encoder hidden states. Defaults to None.
      :type encoder_hidden_states: Any
      :param encoder_attention_mask: Encoder attention mask. Defaults to None.
      :type encoder_attention_mask: Any
      :param Q: Q. Defaults to None.
      :type Q: Any
      :param e1_e2_start: Start and end indices for the entities in the input IDs. Defaults to None.
      :type e1_e2_start: Any
      :param pooled_output: The pooled output. Defaults to None.
      :type pooled_output: Any
      :Returns: **Optional[Tuple[torch.Tensor, torch.Tensor]]** -- logits for the relation classification task.

   .. py:method:: output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)

      :param pooled_output: Embedding of the CLS token.
      :type pooled_output: torch.Tensor
      :param sequence_output: Hidden states (embeddings) for each token in the input text.
      :type sequence_output: torch.Tensor
      :param input_ids: Input token IDs.
      :type input_ids: torch.Tensor
      :param e1_e2_start: Token positions of the annotation tags.
      :type e1_e2_start: torch.Tensor
      :Returns: **torch.Tensor** -- classification probabilities for each token.

   .. py:method:: load(pretrained_model_name_or_path, relcat_config, model_config)
      :classmethod:

      Load the model from the given path.

      :param pretrained_model_name_or_path: Path to load the model from.
      :type pretrained_model_name_or_path: str
      :param relcat_config: RelCAT config.
      :type relcat_config: ConfigRelCAT
      :param model_config: The model-specific config.
      :type model_config: BaseConfig_RelationExtraction
      :returns: **BaseModel_RelationExtraction** -- The loaded model.
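
The parameters of ``output2logits`` (the CLS embedding, the per-token hidden states, and the entity-tag positions in ``e1_e2_start``) are consistent with the common "entity start" representation for relation classification, in which the hidden states at the two entity-marker tokens are combined with the CLS embedding before the dense layers. The snippet below is a minimal, framework-free sketch of that gathering step, using nested Python lists in place of tensors. It assumes that ``e1_e2_start`` holds one ``(e1_start, e2_start)`` pair of token positions per batch item and that the three vectors are concatenated; neither detail is confirmed by the documentation above, and ``gather_entity_features`` is a hypothetical helper, not part of MedCAT.

.. code-block:: python

   from typing import List, Sequence, Tuple

   Vector = List[float]


   def gather_entity_features(
       pooled_output: Sequence[Vector],
       sequence_output: Sequence[Sequence[Vector]],
       e1_e2_start: Sequence[Tuple[int, int]],
   ) -> List[Vector]:
       """Hypothetical helper: build one relation feature vector per example.

       pooled_output:   [batch][hidden]           CLS embedding per example
       sequence_output: [batch][seq_len][hidden]  embedding of every token
       e1_e2_start:     [batch] of (e1_start, e2_start) token positions (assumed)
       """
       features = []
       for cls_vec, tokens, (e1, e2) in zip(pooled_output, sequence_output, e1_e2_start):
           # Assumed layout (not confirmed by the MedCAT docs):
           # [CLS embedding | entity-1 start token | entity-2 start token]
           features.append(list(cls_vec) + list(tokens[e1]) + list(tokens[e2]))
       return features


   # Toy batch: 1 example, 4 tokens, hidden size 2.
   pooled = [[0.1, 0.2]]
   seq = [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]]
   starts = [(1, 3)]
   print(gather_entity_features(pooled, seq, starts))
   # [[0.1, 0.2, 2.0, 2.0, 4.0, 4.0]]

With hidden size ``H`` each feature vector has length ``3 * H``; under the same assumption, that would be the input width of the first dense layer (``fc1``).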