rlify.models package

rlify.models.base_model module

class rlify.models.base_model.BaseModel(input_shape, out_shape)

Bases: Module, ABC

Base class for all neural-network models.

is_rnn = False

Class attribute that flags recurrent models; ReccurentLayer sets it to True.

__init__(input_shape, out_shape)
Parameters:
  • input_shape (tuple) – input shape of the model

  • out_shape (tuple) – output shape of the model

get_total_params()

Returns the total number of parameters in the model
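A minimal sketch of how such a count is commonly computed for a torch.nn.Module. It assumes `get_total_params` sums `numel()` over every parameter. The source does not say whether frozen parameters are included, so the helper name `count_params` and its `trainable_only` flag are hypothetical.

```python
import torch.nn as nn


def count_params(module: nn.Module, trainable_only: bool = False) -> int:
    """Hypothetical helper: sum the element counts of a module's parameters."""
    return sum(
        p.numel()
        for p in module.parameters()
        # Skip frozen parameters only when trainable_only is requested.
        if p.requires_grad or not trainable_only
    )


# A Linear(4, 3) layer has a 4x3 weight matrix plus 3 biases.
layer = nn.Linear(4, 3)
print(count_params(layer))  # 15
```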

abstract forward(x)

Forward pass of the model

abstract reset()

Resets the hidden state. Non-recurrent models have no hidden state, so they can implement this as a no-op (pass).

property device

Returns the device of the model

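Because forward and reset are abstract, a concrete model must implement both before it can be instantiated. The sketch below uses a hypothetical stand-in, `SketchBaseModel`, that mirrors the documented constructor and abstract methods, so it runs without rlify installed. `LinearHead` is also hypothetical.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class SketchBaseModel(nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented BaseModel interface."""

    is_rnn = False

    def __init__(self, input_shape, out_shape):
        super().__init__()
        self.input_shape = input_shape
        self.out_shape = out_shape

    @abstractmethod
    def forward(self, x):
        ...

    @abstractmethod
    def reset(self):
        ...


class LinearHead(SketchBaseModel):
    """Non-recurrent example: a single linear layer."""

    def __init__(self, input_shape, out_shape):
        super().__init__(input_shape, out_shape)
        self.linear = nn.Linear(input_shape[0], out_shape[0])

    def forward(self, x):
        return self.linear(x)

    def reset(self):
        # Feed-forward model: there is no hidden state to clear.
        pass


model = LinearHead(input_shape=(4,), out_shape=(2,))
print(model(torch.randn(3, 4)).shape)  # torch.Size([3, 2])
```

Calling `SketchBaseModel((4,), (2,))` directly raises a TypeError, because both abstract methods are still unimplemented.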

rlify.models.fc module

class rlify.models.fc.FC(embed_dim=64, depth=2, activation=ReLU(), *args, **kwargs)

Bases: BaseModel

A basic fully connected model.

__init__(embed_dim=64, depth=2, activation=ReLU(), *args, **kwargs)
Parameters:
  • embed_dim (int) – the embedding dimension

  • depth (int) – the depth of the model

  • activation (torch.nn.Module) – the activation function

  • *args – positional arguments passed to the base class (BaseModel takes input_shape and out_shape)

  • **kwargs – keyword arguments passed to the base class
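The source does not say exactly how embed_dim, depth and activation combine. The sketch below shows one plausible layout: `depth` hidden Linear + activation blocks of width `embed_dim`, followed by an output Linear. The function `build_fc` and this layer arrangement are assumptions, not the FC implementation.

```python
import torch
import torch.nn as nn


def build_fc(in_features, out_features, embed_dim=64, depth=2, activation=nn.ReLU()):
    """Hypothetical layout: `depth` hidden blocks, then an output projection."""
    layers = []
    dim = in_features
    for _ in range(depth):
        # Reusing one activation instance is safe for stateless modules like ReLU.
        layers += [nn.Linear(dim, embed_dim), activation]
        dim = embed_dim
    layers.append(nn.Linear(dim, out_features))
    return nn.Sequential(*layers)


net = build_fc(8, 3, embed_dim=16, depth=2)
print(net(torch.randn(5, 8)).shape)  # torch.Size([5, 3])
```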

forward(x)

Forward pass of the model

reset()

Resets the hidden state. Non-recurrent models have no hidden state, so they can implement this as a no-op (pass).


rlify.models.model_factory module

rlify.models.rnn module

class rlify.models.rnn.ReccurentLayer(input_shape, out_shape)

Bases: BaseModel

Base class for RNNs

is_rnn = True
__init__(input_shape, out_shape)
Parameters:
  • input_shape (tuple) – input shape of the model

  • out_shape (tuple) – output shape of the model

forward(x)

Forward pass of the model

Parameters:

x (PackedSequence) – the input data
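The source only states that forward expects a PackedSequence. The sketch below builds one from variable-length sequences with `torch.nn.utils.rnn.pack_sequence`. The "episodes" framing and the shapes are illustrative assumptions.

```python
import torch
from torch.nn.utils.rnn import pack_sequence

# Three sequences of different lengths; each step is a 4-dim feature vector.
episodes = [torch.randn(5, 4), torch.randn(3, 4), torch.randn(2, 4)]

# enforce_sorted=False lets PyTorch sort by length internally
# and remember how to restore the original order.
packed = pack_sequence(episodes, enforce_sorted=False)

print(packed.data.shape)   # torch.Size([10, 4]): 5 + 3 + 2 time steps
print(packed.batch_sizes)  # tensor([3, 3, 2, 1, 1])
```

Each entry of batch_sizes counts how many sequences are still active at that time step, which is how the RNN skips padding.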

reset()

Resets the hidden state. Non-recurrent models have no hidden state, so they can implement this as a no-op (pass).
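For a recurrent model, reset typically discards the hidden state carried between calls, for example at an episode boundary. The sketch below is a hypothetical `StatefulGRU` wrapper and does not reproduce ReccurentLayer's code. It keeps the last hidden state and clears it in reset, so the next call starts from zeros again.

```python
import torch
import torch.nn as nn


class StatefulGRU(nn.Module):
    """Hypothetical recurrent module that carries its hidden state between calls."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.hidden = None  # None makes nn.GRU start from a zero state

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, h = self.gru(x, self.hidden)
        # Detach so gradients do not flow back across separate calls.
        self.hidden = h.detach()
        return out

    def reset(self) -> None:
        # Forget the carried-over state; the next call starts from zeros.
        self.hidden = None


rnn = StatefulGRU(4, 8)
x = torch.randn(1, 3, 4)
first = rnn(x)
second = rnn(x)  # continues from the state left by the first call
rnn.reset()
third = rnn(x)   # starts from zeros again, like the first call
print(torch.allclose(first, third))  # True
```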

class rlify.models.rnn.GRU(hidden_dim=64, num_grus=2, *args, **kwargs)

Bases: ReccurentLayer

GRU model

__init__(hidden_dim=64, num_grus=2, *args, **kwargs)
Parameters:
  • hidden_dim (int) – the hidden dimension

  • num_grus (int) – the number of GRUs

  • *args – positional arguments passed to the base class

  • **kwargs – keyword arguments passed to the base class
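A sketch of how these parameters might map onto `torch.nn.GRU`. It assumes that num_grus means stacked layers (`num_layers=num_grus`) and that hidden_dim is the GRU hidden size. The source does not confirm this; num_grus could instead mean separately chained GRU modules. The `input_dim` value is hypothetical.

```python
import torch
import torch.nn as nn

hidden_dim, num_grus = 64, 2  # the documented defaults
input_dim = 10               # hypothetical feature size

# Assumption: num_grus stacked GRU layers.
gru = nn.GRU(input_dim, hidden_dim, num_layers=num_grus, batch_first=True)

x = torch.randn(8, 5, input_dim)  # (batch, time, features)
out, h_n = gru(x)
print(out.shape)  # torch.Size([8, 5, 64])
print(h_n.shape)  # torch.Size([2, 8, 64]): one final hidden state per layer
```

batch_first only changes the layout of the input and output tensors. h_n is always shaped (num_layers, batch, hidden).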

forward(x)

Forward pass of the model

Parameters:

x (Tensor) – the input data, passed as a PackedSequence
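`torch.nn.GRU` accepts a PackedSequence directly and returns a PackedSequence of outputs. The sketch below runs a plain `nn.GRU`, used as a stand-in for the rlify GRU class, on packed variable-length input. It then unpacks the result with `pad_packed_sequence`. The sizes are illustrative.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

gru = nn.GRU(input_size=4, hidden_size=8, batch_first=True)

seqs = [torch.randn(5, 4), torch.randn(3, 4), torch.randn(2, 4)]
packed = pack_sequence(seqs, enforce_sorted=False)

# The GRU consumes the packed batch without running over padding.
packed_out, h_n = gru(packed)

# Back to a padded (batch, max_len, hidden) tensor plus the true lengths,
# restored to the original sequence order.
padded, lengths = pad_packed_sequence(packed_out, batch_first=True)
print(padded.shape)      # torch.Size([3, 5, 8])
print(lengths.tolist())  # [5, 3, 2]
```

For a single-layer, unidirectional GRU, h_n[0, i] equals the output at the last valid step of sequence i. Positions past each length are zero-padded.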

reset()

Resets the hidden state. Non-recurrent models have no hidden state, so they can implement this as a no-op (pass).