medcat.utils.relation_extraction.models
Module Contents
Classes
BaseModelBluePrint_RelationExtraction | Base class for the RelCAT models |
BaseModel_RelationExtraction | Base class for the RelCAT models |
- class medcat.utils.relation_extraction.models.BaseModelBluePrint_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)
Bases: torch.nn.Module
Base class for the RelCAT models
- Parameters:
pretrained_model_name_or_path (str) –
relcat_config (medcat.config_rel_cat.ConfigRelCAT) –
model_config (Union[transformers.PretrainedConfig, medcat.utils.relation_extraction.config.BaseConfig_RelationExtraction]) –
- hf_model: transformers.PreTrainedModel
- relcat_config: medcat.config_rel_cat.ConfigRelCAT
- model_config: transformers.PretrainedConfig
- drop_out: torch.nn.Dropout
- fc1: torch.nn.Linear
- fc2: torch.nn.Linear
- fc3: torch.nn.Linear
- __init__(pretrained_model_name_or_path, relcat_config, model_config)
Class to hold the HF model + model_config
- Parameters:
pretrained_model_name_or_path (str) – path to load the model from. This can be a HF model name, e.g. “bert-base-uncased”. If left empty, the model is normally assumed to be loaded from ‘model.dat’ using the RelCAT.load() method, so if you are initializing or training a model from scratch, be sure to base it on an existing model.
relcat_config (ConfigRelCAT) – relcat config.
model_config (PretrainedConfig) – HF bert config for model.
- forward(input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, Q=None, e1_e2_start=None, pooled_output=None)
Forward pass for the model
- Parameters:
input_ids (torch.Tensor) – input token ids. Defaults to None.
attention_mask (torch.Tensor) – attention mask for the input ids. Defaults to None.
token_type_ids (torch.Tensor) – token type ids for the input ids. Defaults to None.
position_ids (Any) – The position IDs. Defaults to None.
head_mask (Any) – The head mask. Defaults to None.
encoder_hidden_states (Any) – Encoder hidden states. Defaults to None.
encoder_attention_mask (Any) – Encoder attention mask. Defaults to None.
Q (Any) – Defaults to None.
e1_e2_start (Any) – start and end indices for the entities in the input ids. Defaults to None.
pooled_output (Any) – The pooled output. Defaults to None.
- Returns:
Optional[Tuple[torch.Tensor, torch.Tensor]] – logits for the relation classification task.
- Return type:
Optional[Tuple[torch.Tensor, torch.Tensor]]
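To make the first three forward arguments concrete, the sketch below builds input_ids, attention_mask and token_type_ids for a small batch using plain Python lists. It is not MedCAT code: the helper pad_batch, the pad id of 0 and the token ids are hypothetical, and a real caller would pass torch.Tensor objects produced by a tokenizer.

```python
from typing import List, Tuple

Batch = List[List[int]]


def pad_batch(token_id_seqs: Batch, pad_id: int = 0) -> Tuple[Batch, Batch, Batch]:
    """Pad sequences to equal length and derive the matching masks.

    Hypothetical helper for illustration only; pad_id=0 is an assumption.
    """
    max_len = max(len(seq) for seq in token_id_seqs)
    input_ids, attention_mask, token_type_ids = [], [], []
    for seq in token_id_seqs:
        n_pad = max_len - len(seq)
        # Real tokens first, then padding up to the longest sequence.
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(seq) + [0] * n_pad)
        # Single-segment input: every position belongs to segment 0.
        token_type_ids.append([0] * max_len)
    return input_ids, attention_mask, token_type_ids


# Made-up token ids for two sentences of different length.
ids, mask, types = pad_batch([[5, 6, 7], [8, 9]])
print(ids)    # [[5, 6, 7], [8, 9, 0]]
print(mask)   # [[1, 1, 1], [1, 1, 0]]
print(types)  # [[0, 0, 0], [0, 0, 0]]
```

All three lists share the same shape (batch size by sequence length), which is the layout the corresponding tensors are expected to have.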
- output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)
Convert the output of the model to logits
- Parameters:
pooled_output (torch.Tensor) – output of the pooled layer.
sequence_output (torch.Tensor) – output of the sequence layer.
input_ids (torch.Tensor) – input token ids.
e1_e2_start (torch.Tensor) – start and end indices for the entities in the input ids.
- Returns:
logits (torch.Tensor) – logits for the relation classification task.
- Return type:
Optional[torch.Tensor]
- class medcat.utils.relation_extraction.models.BaseModel_RelationExtraction(relcat_config, model_config, pretrained_model_name_or_path)
Bases: BaseModelBluePrint_RelationExtraction
Base class for the RelCAT models
- Parameters:
relcat_config (medcat.config_rel_cat.ConfigRelCAT) –
model_config (medcat.utils.relation_extraction.config.BaseConfig_RelationExtraction) –
- name = 'basemodel_relcat'
- log
- __init__(relcat_config, model_config, pretrained_model_name_or_path)
Class to hold the HF model + model_config
- Parameters:
pretrained_model_name_or_path (str) – path to load the model from. This can be a HF model name, e.g. “bert-base-uncased”. If left empty, the model is normally assumed to be loaded from ‘model.dat’ using the RelCAT.load() method, so if you are initializing or training a model from scratch, be sure to base it on an existing model.
relcat_config (ConfigRelCAT) – relcat config.
model_config (PretrainedConfig) – HF bert config for model.
- _reinitialize_dense_and_frozen_layers(relcat_config)
Reinitialize the dense layers of the model
- Parameters:
relcat_config (ConfigRelCAT) – relcat config.
- Return type:
None
- forward(input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, Q=None, e1_e2_start=None, pooled_output=None)
Forward pass for the model
- Parameters:
input_ids (torch.Tensor) – input token ids. Defaults to None.
attention_mask (torch.Tensor) – attention mask for the input ids. Defaults to None.
token_type_ids (torch.Tensor) – token type ids for the input ids. Defaults to None.
position_ids (Any) – The position IDs. Defaults to None.
head_mask (Any) – The head mask. Defaults to None.
encoder_hidden_states (Any) – Encoder hidden states. Defaults to None.
encoder_attention_mask (Any) – Encoder attention mask. Defaults to None.
Q (Any) – Defaults to None.
e1_e2_start (Any) – start and end indices for the entities in the input ids. Defaults to None.
pooled_output (Any) – The pooled output. Defaults to None.
- Returns:
Tuple[torch.Tensor, torch.Tensor] – logits for the relation classification task.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)
- Parameters:
pooled_output (torch.Tensor) – embedding of the CLS token
sequence_output (torch.Tensor) – hidden states/embeddings for each token in the input text
input_ids (torch.Tensor) – input token ids.
e1_e2_start (torch.Tensor) – annotation tags token position
- Returns:
torch.Tensor – classification probabilities for each token.
- Return type:
torch.Tensor
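One plausible reading of how these inputs combine is sketched below: for each example, the hidden states at the two entity tag positions from e1_e2_start are appended to the pooled CLS embedding, and the resulting vector is what the dense layers (fc1 to fc3) would then classify. This is an assumption about the internals rather than the documented behaviour; build_relation_features is a hypothetical name, and plain Python lists stand in for torch.Tensor values.

```python
from typing import List, Tuple

Vector = List[float]


def build_relation_features(
    pooled_output: List[Vector],
    sequence_output: List[List[Vector]],
    e1_e2_start: List[Tuple[int, int]],
) -> List[Vector]:
    """Concatenate [CLS embedding, entity-1 state, entity-2 state] per example.

    Hypothetical helper; assumes e1_e2_start holds one (e1, e2) token
    position pair per example in the batch.
    """
    features = []
    for pooled, tokens, (e1_pos, e2_pos) in zip(
        pooled_output, sequence_output, e1_e2_start
    ):
        # List concatenation mirrors concatenating along the hidden dimension.
        features.append(pooled + tokens[e1_pos] + tokens[e2_pos])
    return features


# Toy batch: 1 example, 4 tokens, hidden size 2.
pooled = [[0.1, 0.2]]
sequence = [[[1.0, 1.5], [2.0, 2.5], [3.0, 3.5], [4.0, 4.5]]]
positions = [(1, 3)]

features = build_relation_features(pooled, sequence, positions)
print(features)  # [[0.1, 0.2, 2.0, 2.5, 4.0, 4.5]]
```

Under this assumption the classifier input is three times the hidden size, which is worth keeping in mind when sizing the first dense layer.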
- classmethod load(pretrained_model_name_or_path, relcat_config, model_config)
Load the model from the given path
- Parameters:
pretrained_model_name_or_path (str) – path to load the model from.
relcat_config (ConfigRelCAT) – relcat config.
model_config (BaseConfig_RelationExtraction) – The model-specific config.
- Returns:
BaseModel_RelationExtraction – The loaded model.
- Return type: