Topology Layers
- class gdeep.topology_layers.Persformer(config: PersformerConfig)
Persformer model as described in the paper: https://arxiv.org/abs/2112.15210
Examples:
from gdeep.topology_layers import Persformer, PersformerConfig

# Initialize the configuration object
config = PersformerConfig()
# Initialize the model
model = Persformer(config)
- build_model()
Build the model.
- forward(input_batch: Tensor, attention_mask: Tensor | None = None) -> Tensor
Forward pass of the model.
- Args:
input_batch: The input batch, of shape (batch_size, sequence_length, 2 + num_homology_types).
attention_mask: The attention mask, of shape (batch_size, sequence_length).
- Returns:
The logits of the model, of shape (batch_size, sequence_length, 1).
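The input shape above suggests that each persistence diagram is passed in as a padded sequence of points, where each point is (birth, death) followed by a one-hot encoding of its homology type. With the default input_size=6, that would correspond to 4 homology types. The pure-Python sketch below builds such a batch together with its attention mask. The feature order, zero-padding, and the mask convention (1.0 for a real point, 0.0 for padding) are assumptions, and `encode_diagrams` is a hypothetical helper, not part of gdeep. The resulting nested lists can be converted with `torch.tensor(...)`.

```python
from typing import List, Sequence, Tuple

# A persistence point: (birth, death, homology_dimension). Hypothetical layout.
Point = Tuple[float, float, int]


def encode_diagrams(
    diagrams: Sequence[Sequence[Point]],
    num_homology_types: int,
) -> Tuple[List[List[List[float]]], List[List[float]]]:
    """Pad diagrams to (batch, seq_len, 2 + k) features and a (batch, seq_len) mask."""
    seq_len = max((len(d) for d in diagrams), default=0)
    width = 2 + num_homology_types
    batch: List[List[List[float]]] = []
    masks: List[List[float]] = []
    for diagram in diagrams:
        rows: List[List[float]] = []
        for birth, death, dim in diagram:
            # One-hot homology type; raises IndexError if dim >= num_homology_types.
            one_hot = [0.0] * num_homology_types
            one_hot[dim] = 1.0
            rows.append([float(birth), float(death)] + one_hot)
        n_real = len(rows)
        # Pad with all-zero rows up to the longest diagram in the batch.
        rows.extend([[0.0] * width for _ in range(seq_len - n_real)])
        batch.append(rows)
        # 1.0 marks a real point, 0.0 marks padding (assumed convention).
        masks.append([1.0] * n_real + [0.0] * (seq_len - n_real))
    return batch, masks


diagrams = [
    [(0.0, 1.0, 0), (0.2, 0.5, 1)],  # two points: one in H0, one in H1
    [(0.1, 0.9, 0)],                 # one H0 point, padded to length 2
]
x, mask = encode_diagrams(diagrams, num_homology_types=4)
print(len(x), len(x[0]), len(x[0][0]))  # -> 2 2 6
print(mask)  # -> [[1.0, 1.0], [1.0, 0.0]]
```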
- class gdeep.topology_layers.PersformerConfig(input_size: int = 6, output_size: int = 2, hidden_size: int = 32, num_attention_layers: int = 2, num_attention_heads: int = 4, intermediate_size: int = 32, hidden_act: ActivationFunction = ActivationFunction.GELU, hidden_dropout_prob: float = 0.1, attention_probs_dropout_prob: float = 0.1, layer_norm_eps: float = 1e-12, classifier_dropout_prob: float = 0.1, use_layer_norm: LayerNormStyle = LayerNormStyle.NO_LAYER_NORMALIZATION, attention_type: AttentionType = AttentionType.DOT_PRODUCT, pooler_type: PoolerType = PoolerType.ATTENTION, use_attention_only: bool = False, use_skip_connections_for_persformer_blocks=False, **kwargs)
Configuration class that defines a Persformer model.
Examples:
from gdeep.topology_layers import Persformer, PersformerConfig

# Initialize the configuration object
config = PersformerConfig()
# Initialize the model
model = Persformer(config)
# Access the configuration object
config = model.config
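Two shape relations are worth checking before building a model from a custom configuration. First, the forward documentation implies input_size = 2 + num_homology_types (the default input_size=6 matches 4 homology types). Second, standard multi-head attention splits hidden_size evenly across heads, so hidden_size is usually required to be divisible by num_attention_heads (32 / 4 = 8 with the defaults). Whether gdeep enforces either relation is an assumption here, and `check_persformer_dims` is a hypothetical helper, not part of the library.

```python
def check_persformer_dims(
    input_size: int = 6,
    hidden_size: int = 32,
    num_attention_heads: int = 4,
    num_homology_types: int = 4,
) -> int:
    """Return the per-head dimension after checking two assumed shape constraints."""
    # Each input point is assumed to carry (birth, death) plus a one-hot homology type.
    if input_size != 2 + num_homology_types:
        raise ValueError(
            f"input_size={input_size} but 2 + num_homology_types={2 + num_homology_types}"
        )
    # Multi-head attention conventionally splits hidden_size evenly across heads.
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            f"hidden_size={hidden_size} is not divisible by "
            f"num_attention_heads={num_attention_heads}"
        )
    return hidden_size // num_attention_heads


print(check_persformer_dims())  # -> 8
```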
- class gdeep.topology_layers.PersformerWrapper(**kwargs)
Wrapper around Persformer that makes it compatible with the hyperparameter optimization (HPO) classes.
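An HPO search typically samples hyperparameters as a flat dictionary, which fits a constructor taking only **kwargs. The sketch below shows that pattern with hypothetical stand-ins (`ToyConfig`, `ToyModel`, `ToyWrapper`). It assumes the keyword arguments are forwarded to a configuration object that is then used to build the model. It is not PersformerWrapper's actual implementation.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ToyConfig:
    """Hypothetical stand-in for PersformerConfig with a few of its fields."""
    input_size: int = 6
    hidden_size: int = 32
    num_attention_heads: int = 4


class ToyModel:
    """Hypothetical model that only records the config it was built from."""

    def __init__(self, config: ToyConfig) -> None:
        self.config = config


class ToyWrapper:
    """Builds a config, then a model, from flat keyword arguments."""

    def __init__(self, **kwargs: Any) -> None:
        # Unspecified fields keep their dataclass defaults.
        self.config = ToyConfig(**kwargs)
        self.model = ToyModel(self.config)


# An HPO loop samples a flat dict of hyperparameters and passes it straight through.
sampled: Dict[str, Any] = {"hidden_size": 64, "num_attention_heads": 8}
wrapper = ToyWrapper(**sampled)
print(wrapper.model.config)
```

The benefit of this design is that the HPO code never needs to know about the configuration class: it only has to produce keyword arguments whose names match the configuration fields.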
- forward(input: Tensor, attention_mask: Tensor | None = None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
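The note above refers to PyTorch's forward-hook mechanism. The pure-Python imitation below shows why calling the instance matters. `TinyModule` is hypothetical and only mimics how `torch.nn.Module.__call__` runs `forward` and then the registered forward hooks. It is simplified: real hooks receive a tuple of positional inputs, and registration returns a removable handle.

```python
from typing import Any, Callable, List

Hook = Callable[["TinyModule", Any, Any], None]


class TinyModule:
    """Simplified imitation of how torch.nn.Module.__call__ wraps forward with hooks."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: float) -> float:
        return 2.0 * x

    def __call__(self, x: float) -> float:
        out = self.forward(x)
        # Hooks run only when the instance is called, never from forward itself.
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


seen: List[float] = []
m = TinyModule()
m.register_forward_hook(lambda mod, inp, out: seen.append(out))
m(3.0)          # runs forward and then the hook
m.forward(4.0)  # runs forward only; the hook is silently skipped
print(seen)  # -> [6.0]
```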