medcat.utils.relation_extraction.models

Module Contents

Classes

BaseModelBluePrint_RelationExtraction

Base class for the RelCAT models

BaseModel_RelationExtraction

Base class for the RelCAT models

class medcat.utils.relation_extraction.models.BaseModelBluePrint_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)

Bases: torch.nn.Module

Base class for the RelCAT models

Attributes:
hf_model: transformers.PreTrainedModel
relcat_config: medcat.config_rel_cat.ConfigRelCAT
model_config: transformers.PretrainedConfig
drop_out: torch.nn.Dropout
fc1: torch.nn.Linear
fc2: torch.nn.Linear
fc3: torch.nn.Linear
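The attributes above describe a small classification head (drop_out, fc1, fc2, fc3) on top of hf_model. The following pure-Python sketch shows how such a head could process one feature vector. It assumes the order dropout → fc1 → dropout → fc2 → fc3 with no activations in between; the exact ordering inside RelCAT is not stated here. The helper names (`linear`, `dropout`, `classification_head`) are hypothetical. They mirror the semantics of torch.nn.Linear (weight of shape out_features × in_features, y = Wx + b) and torch.nn.Dropout (inverted dropout during training, identity at evaluation).

```python
import random
from typing import List, Optional, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]     # (out_features, in_features), like torch.nn.Linear.weight
Layer = Tuple[Matrix, Vector]  # (weight, bias)


def linear(x: Sequence[float], weight: Matrix, bias: Sequence[float]) -> Vector:
    """Affine map y = W x + b for a single sample, as torch.nn.Linear computes."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def dropout(x: Sequence[float], p: float, training: bool, rng: random.Random) -> Vector:
    """Inverted dropout: in training, zero each element with probability p and
    scale the survivors by 1 / (1 - p); at evaluation time, return x unchanged."""
    if not training or p == 0.0:
        return list(x)
    return [0.0 if rng.random() < p else xi / (1.0 - p) for xi in x]


def classification_head(
    features: Sequence[float],
    fc1: Layer,
    fc2: Layer,
    fc3: Layer,
    p: float = 0.1,
    training: bool = False,
    rng: Optional[random.Random] = None,
) -> Vector:
    """Hypothetical head (assumed order): dropout -> fc1 -> dropout -> fc2 -> fc3."""
    rng = rng or random.Random(0)
    x = dropout(features, p, training, rng)
    x = linear(x, *fc1)
    x = dropout(x, p, training, rng)
    x = linear(x, *fc2)
    return linear(x, *fc3)  # one logit per relation class
```

In evaluation mode (training=False) the head reduces to a chain of affine maps, so the same input always yields the same logits.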
__init__(pretrained_model_name_or_path, relcat_config, model_config)

Holds the Hugging Face (HF) model together with its model_config.

Parameters:
  • pretrained_model_name_or_path (str) – path or HF model name to load the model from, e.g. “bert-base-uncased”. If left empty, the model is normally assumed to be loaded from “model.dat” by the RelCAT.load() method, so when initializing or training a model from scratch, be sure to base it on an existing (e.g. HF) model.

  • relcat_config (ConfigRelCAT) – relcat config.

  • model_config (PretrainedConfig) – HF BERT config for the model.
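The rule for pretrained_model_name_or_path can be restated as a small decision function. This is a hypothetical sketch, not part of medcat: `resolve_model_source` and its `save_dir` argument are illustrative names, and where “model.dat” actually lives is decided by RelCAT.load(), not by this helper.

```python
import os
from typing import Tuple


def resolve_model_source(pretrained_model_name_or_path: str, save_dir: str) -> Tuple[str, str]:
    """Hypothetical helper restating the documented rule.

    Returns (source_kind, location):
      * a non-empty value is treated as an HF model name or local path;
      * an empty value means the weights are expected in 'model.dat',
        which RelCAT.load() is responsible for reading.
    """
    if pretrained_model_name_or_path:
        return "huggingface", pretrained_model_name_or_path
    return "model.dat", os.path.join(save_dir, "model.dat")
```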

forward(input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, Q=None, e1_e2_start=None, pooled_output=None)

Forward pass for the model

Parameters:
  • input_ids (torch.Tensor) – input token ids. Defaults to None.

  • attention_mask (torch.Tensor) – attention mask for the input ids. Defaults to None.

  • token_type_ids (torch.Tensor) – token type ids for the input ids. Defaults to None.

  • position_ids (Any) – The position IDs. Defaults to None.

  • head_mask (Any) – The head mask. Defaults to None.

  • encoder_hidden_states (Any) – Encoder hidden states. Defaults to None.

  • encoder_attention_mask (Any) – Encoder attention mask. Defaults to None.

  • Q (Any) – Defaults to None.

  • e1_e2_start (Any) – start and end indices for the entities in the input ids. Defaults to None.

  • pooled_output (Any) – The pooled output. Defaults to None.

Returns:

Optional[Tuple[torch.Tensor, torch.Tensor]] – logits for the relation classification task.

Return type:

Optional[Tuple[torch.Tensor, torch.Tensor]]
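The sketch below shows how the list-valued inputs to forward could be assembled for a batch before being wrapped with torch.tensor(...): padding to a common length, an attention mask (1 for real tokens, 0 for padding), and e1_e2_start as the positions of the two entity tags. The tag ids and the helper name `build_relcat_batch` are hypothetical; real tag ids come from the tokenizer and relcat_config. It also assumes e1_e2_start stores the position of each tag itself.

```python
from typing import List, Sequence, Tuple


def build_relcat_batch(
    sequences: Sequence[Sequence[int]],
    e1_tag_id: int,
    e2_tag_id: int,
    pad_id: int = 0,
) -> Tuple[List[List[int]], List[List[int]], List[Tuple[int, int]]]:
    """Pad token-id sequences and locate the two entity tags in each one.

    Returns (input_ids, attention_mask, e1_e2_start), where e1_e2_start[i]
    holds the positions of the first e1 tag and the first e2 tag in
    sequence i. Raises ValueError if a sequence lacks either tag.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask, e1_e2_start = [], [], []
    for seq in sequences:
        tokens = list(seq)
        n_pad = max_len - len(tokens)
        input_ids.append(tokens + [pad_id] * n_pad)
        attention_mask.append([1] * len(tokens) + [0] * n_pad)  # 1 = real token, 0 = padding
        e1_e2_start.append((tokens.index(e1_tag_id), tokens.index(e2_tag_id)))
    return input_ids, attention_mask, e1_e2_start
```

Entity order is preserved per example, so e1 may appear after e2 in the text (as in the second example of the test below).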

output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)

Convert the output of the model to logits

Parameters:
  • pooled_output (torch.Tensor) – output of the pooled layer.

  • sequence_output (torch.Tensor) – output of the sequence layer.

  • input_ids (torch.Tensor) – input token ids.

  • e1_e2_start (torch.Tensor) – start and end indices for the entities in the input ids.

Returns:

logits (torch.Tensor) – logits for the relation classification task.

Return type:

Optional[torch.Tensor]
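One common way to turn pooled_output, sequence_output and e1_e2_start into a single feature vector is to concatenate the CLS embedding with the hidden states at the two entity positions. The pure-Python sketch below illustrates that entity-start pooling; it is an assumption about what output2logits does before its dense layers, and `entity_start_features` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def entity_start_features(
    pooled_output: Sequence[Sequence[float]],
    sequence_output: Sequence[Sequence[Sequence[float]]],
    e1_e2_start: Sequence[Tuple[int, int]],
) -> List[List[float]]:
    """For each example, concatenate [CLS embedding; h(e1 start); h(e2 start)].

    pooled_output:   batch x hidden
    sequence_output: batch x seq_len x hidden
    e1_e2_start:     batch x 2 (token positions of the two entity tags)
    result:          batch x (3 * hidden)
    """
    features = []
    for cls_vec, token_vecs, (e1, e2) in zip(pooled_output, sequence_output, e1_e2_start):
        features.append(list(cls_vec) + list(token_vecs[e1]) + list(token_vecs[e2]))
    return features
```

In PyTorch the same gather can be written with advanced indexing, e.g. `rows = torch.arange(B)` followed by `torch.cat((pooled_output, sequence_output[rows, e1_e2_start[:, 0]], sequence_output[rows, e1_e2_start[:, 1]]), dim=1)`, which is why a head fed this way expects an input width of three times the hidden size.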

class medcat.utils.relation_extraction.models.BaseModel_RelationExtraction(relcat_config, model_config, pretrained_model_name_or_path)

Bases: BaseModelBluePrint_RelationExtraction

Base class for the RelCAT models

Attributes:
name = 'basemodel_relcat'
log
__init__(relcat_config, model_config, pretrained_model_name_or_path)

Holds the Hugging Face (HF) model together with its model_config.

Parameters:
  • pretrained_model_name_or_path (str) – path or HF model name to load the model from, e.g. “bert-base-uncased”. If left empty, the model is normally assumed to be loaded from “model.dat” by the RelCAT.load() method, so when initializing or training a model from scratch, be sure to base it on an existing (e.g. HF) model.

  • relcat_config (ConfigRelCAT) – relcat config.

  • model_config (PretrainedConfig) – HF BERT config for the model.

_reinitialize_dense_and_frozen_layers(relcat_config)

Reinitialize the dense layers of the model

Parameters:

relcat_config (ConfigRelCAT) – relcat config.

Return type:

None
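The method name suggests that, besides the dense layers, some layers of hf_model are frozen. A typical way to do this is to keep only parameters whose names contain one of a set of "unfrozen" substrings trainable. The sketch below assumes that name-based scheme; `trainable_mask` and the example substrings are hypothetical, and which layers RelCAT actually leaves unfrozen is governed by relcat_config, not by this example.

```python
from typing import Dict, Iterable


def trainable_mask(param_names: Iterable[str], unfrozen: Iterable[str]) -> Dict[str, bool]:
    """Map each parameter name to True (trainable) if it contains any of the
    'unfrozen' substrings, else False (frozen)."""
    keys = list(unfrozen)  # iterated once per parameter name
    return {name: any(key in name for key in keys) for name in param_names}
```

With a PyTorch model the mask would be applied as `for name, param in model.named_parameters(): param.requires_grad = mask[name]`. Because matching is by substring, "encoder.layer.1" would also match "encoder.layer.10" and "encoder.layer.11"; including the trailing dot avoids that.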

forward(input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, Q=None, e1_e2_start=None, pooled_output=None)

Forward pass for the model

Parameters:
  • input_ids (torch.Tensor) – input token ids. Defaults to None.

  • attention_mask (torch.Tensor) – attention mask for the input ids. Defaults to None.

  • token_type_ids (torch.Tensor) – token type ids for the input ids. Defaults to None.

  • position_ids (Any) – The position IDs. Defaults to None.

  • head_mask (Any) – The head mask. Defaults to None.

  • encoder_hidden_states (Any) – Encoder hidden states. Defaults to None.

  • encoder_attention_mask (Any) – Encoder attention mask. Defaults to None.

  • Q (Any) –

    1. Defaults to None.

  • e1_e2_start (Any) – start and end indices for the entities in the input ids. Defaults to None.

  • pooled_output (Any) – The pooled output. Defaults to None.

Returns:

Optional[Tuple[torch.Tensor, torch.Tensor]] – logits for the relation classification task.

Return type:

Tuple[torch.Tensor, torch.Tensor]

output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)
Parameters:
  • pooled_output (torch.Tensor) – embedding of the CLS token

  • sequence_output (torch.Tensor) – hidden states/embeddings for each token in the input text

  • input_ids (torch.Tensor) – input token ids.

  • e1_e2_start (torch.Tensor) – token positions of the annotation tags.

Returns:

torch.Tensor – classification probabilities for each token.

Return type:

torch.Tensor
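If the returned tensor holds raw per-class scores, probabilities are obtained with a softmax over the class dimension (torch.softmax(logits, dim=-1) in PyTorch). The pure-Python sketch below shows a numerically stable softmax and the predicted class for one example; `softmax` and `predict` are illustrative helpers, not medcat functions.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits: Sequence[float]) -> int:
    """Index of the most probable class; softmax is monotonic, so argmax of the
    logits gives the same answer as argmax of the probabilities."""
    return max(range(len(logits)), key=lambda i: logits[i])
```

Subtracting the maximum leaves the result unchanged mathematically but keeps math.exp from overflowing on large logits.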

classmethod load(pretrained_model_name_or_path, relcat_config, model_config)

Load the model from the given path

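Parameters:
  • pretrained_model_name_or_path – path or HF model name to load the model from.

  • relcat_config – relcat config.

  • model_config – HF config for the model.

Returns: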

BaseModel_RelationExtraction – The loaded model.

Return type:

BaseModel_RelationExtraction