Topology Layers

class gdeep.topology_layers.Persformer(config: PersformerConfig)

Persformer model as described in the paper: https://arxiv.org/abs/2112.15210

Examples:

from gdeep.topology_layers import Persformer, PersformerConfig
# Initialize the configuration object
config = PersformerConfig()
# Initialize the model
model = Persformer(config)

build_model()

Build the model.

forward(input_batch: Tensor, attention_mask: Tensor | None = None) → Tensor

Forward pass of the model.

Args:

input_batch: The input batch, of shape (batch_size, sequence_length, 2 + num_homology_types)

attention_mask: The attention mask, of shape (batch_size, sequence_length)

Returns:

The logits of the model, of shape (batch_size, sequence_length, 1)
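The configuration's default attention_type is AttentionType.DOT_PRODUCT. As a rough, self-contained illustration (not the library's implementation), single-head scaled dot-product attention can be sketched in plain Python:

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def dot_product_attention(Q, K, V):
    """Single-head scaled dot-product attention (illustrative sketch).

    Q, K, V are lists of row vectors (lists of floats) sharing
    dimension d. Returns one output row per query row.
    """
    d = len(Q[0])
    out = []
    for q in Q:
        # Score each key against the query: q . k / sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        weights = softmax(scores)
        # Each output row is a convex combination of the value rows.
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out

# Tiny example: 2 tokens, feature dimension 2.
Q = K = V = [[1.0, 0.0], [0.0, 1.0]]
rows = dot_product_attention(Q, K, V)
```

Because the attention weights form a probability distribution over the value rows, each output row here sums to 1 and leans toward the value row matching its query.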

class gdeep.topology_layers.PersformerConfig(input_size: int = 6, output_size: int = 2, hidden_size: int = 32, num_attention_layers: int = 2, num_attention_heads: int = 4, intermediate_size: int = 32, hidden_act: ActivationFunction = ActivationFunction.GELU, hidden_dropout_prob: float = 0.1, attention_probs_dropout_prob: float = 0.1, layer_norm_eps: float = 1e-12, classifier_dropout_prob: float = 0.1, use_layer_norm: LayerNormStyle = LayerNormStyle.NO_LAYER_NORMALIZATION, attention_type: AttentionType = AttentionType.DOT_PRODUCT, pooler_type: PoolerType = PoolerType.ATTENTION, use_attention_only: bool = False, use_skip_connections_for_persformer_blocks=False, **kwargs)

Configuration class to define a Persformer model.

Examples:

from gdeep.topology_layers import Persformer, PersformerConfig
# Initialize the configuration object
config = PersformerConfig()
# Initialize the model
model = Persformer(config)
# Access the configuration object
config = model.config
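Persistence diagrams in a batch typically contain different numbers of points, so sequences must be padded to a common length, with the padding flagged in the attention mask. A minimal sketch of that preprocessing (the helper name and feature layout below are illustrative, not part of the library):

```python
def pad_batch(diagrams, pad_value=0.0):
    """Pad variable-length persistence diagrams to a common length.

    diagrams: list of diagrams, each a list of feature vectors,
    e.g. (birth, death) pairs optionally followed by a one-hot
    homology-type encoding. Returns (padded_batch, attention_mask),
    where the mask holds 1 for real points and 0 for padding.
    """
    max_len = max(len(d) for d in diagrams)
    dim = len(diagrams[0][0])
    batch, mask = [], []
    for d in diagrams:
        n_pad = max_len - len(d)
        batch.append(list(d) + [[pad_value] * dim] * n_pad)
        mask.append([1] * len(d) + [0] * n_pad)
    return batch, mask

# Two diagrams with 2 and 1 points; features are (birth, death).
batch, mask = pad_batch([[[0.0, 1.0], [0.2, 0.5]], [[0.1, 0.9]]])
```

After padding, both sequences have length 2, and the mask tells the attention layers to ignore the padded point in the second diagram.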
class gdeep.topology_layers.PersformerWrapper(**kwargs)

A wrapper around Persformer that provides compatibility with the HPO (hyperparameter optimization) classes.

forward(input: Tensor, attention_mask: Tensor | None = None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
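The note's point is that calling the Module instance runs the registered hooks around forward(), whereas calling forward() directly silently skips them. A toy stand-in for that call protocol (plain Python, not the actual PyTorch Module class):

```python
class Module:
    """Minimal stand-in for the Module call protocol (illustrative)."""
    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args):
        out = self.forward(*args)
        # Hooks only run when the instance is called, not forward().
        for hook in self._forward_hooks:
            hook(self, args, out)
        return out

class Doubler(Module):
    def forward(self, x):
        return 2 * x

m = Doubler()
seen = []
m.register_forward_hook(lambda mod, inp, out: seen.append(out))
m(3)          # runs the hook: seen becomes [6]
m.forward(3)  # bypasses the hook: seen is unchanged
```

This is why user code should call `model(input_batch)` rather than `model.forward(input_batch)`.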

Attention

Pooling Layers