rlify.models package
rlify.models.base_model module
- class rlify.models.base_model.BaseModel(input_shape, out_shape)
Bases: Module, ABC
Base class for all NN models
- is_rnn = False
- __init__(input_shape, out_shape)
- Parameters:
input_shape (tuple) – input shape of the model
out_shape (tuple) – output shape of the model
- get_total_params()
Returns the total number of parameters in the model
- abstract forward(x)
Forward pass of the model
- abstract reset()
Resets the hidden state; non-RNN models may implement this as a no-op (pass)
- property device
Returns the device of the model
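The subclassing contract above can be sketched without any dependencies. The class below is a hypothetical, dependency-free stand-in (the real BaseModel also inherits torch.nn.Module), showing that subclasses must implement forward and reset:

```python
from abc import ABC, abstractmethod

# Hypothetical sketch of the BaseModel contract; names chosen for
# illustration, not part of rlify.
class BaseModelSketch(ABC):
    is_rnn = False

    def __init__(self, input_shape, out_shape):
        self.input_shape = input_shape
        self.out_shape = out_shape

    @abstractmethod
    def forward(self, x):
        """Forward pass of the model."""

    @abstractmethod
    def reset(self):
        """Reset hidden state; non-RNN models may simply pass."""

# A concrete subclass must implement both abstract methods.
class IdentityModel(BaseModelSketch):
    def forward(self, x):
        return x

    def reset(self):
        pass  # no hidden state to reset

model = IdentityModel(input_shape=(4,), out_shape=(4,))
```

Instantiating BaseModelSketch directly raises TypeError, since both abstract methods are unimplemented.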
rlify.models.fc module
- class rlify.models.fc.FC(embed_dim=64, depth=2, activation=ReLU(), *args, **kwargs)
Bases: BaseModel
A basic fully connected model
- __init__(embed_dim=64, depth=2, activation=ReLU(), *args, **kwargs)
- Parameters:
embed_dim (int) – the embedding dimension
depth (int) – the depth of the model
activation (torch.nn.Module) – the activation function
*args – positional args passed to the base class
**kwargs – keyword args passed to the base class
- forward(x)
Forward pass of the model
- reset()
Resets the hidden state; non-RNN models may implement this as a no-op (pass)
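As an illustration of what get_total_params would report for a model like FC, here is a hand count for a plain MLP layout. The layout is an assumption (the actual FC architecture may differ): depth Linear layers of width embed_dim between the input and output projections.

```python
def fc_param_count(in_dim, embed_dim, depth, out_dim):
    # Assumed layout: in_dim -> embed_dim, then (depth - 1) layers of
    # embed_dim -> embed_dim, then embed_dim -> out_dim. Each Linear
    # layer contributes in*out weights plus out biases.
    dims = [in_dim] + [embed_dim] * depth + [out_dim]
    return sum(a * b + b for a, b in zip(dims, dims[1:]))

# e.g. the defaults embed_dim=64, depth=2, mapping 8 inputs to 4 outputs:
# (8*64 + 64) + (64*64 + 64) + (64*4 + 4) = 4996 parameters
```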
rlify.models.model_factory module
rlify.models.rnn module
- class rlify.models.rnn.ReccurentLayer(input_shape, out_shape)
Bases: BaseModel
Base class for RNNs
- is_rnn = True
- __init__(input_shape, out_shape)
- Parameters:
input_shape (tuple) – input shape of the model
out_shape (tuple) – output shape of the model
- forward(x)
Forward pass of the model
- Parameters:
x (PackedSequence) – the input data
- reset()
Resets the hidden state; non-RNN models may implement this as a no-op (pass)
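The reset contract matters for recurrent layers (is_rnn = True) because they carry state across forward calls. A toy, torch-free illustration (CumSumLayer is hypothetical, not part of rlify):

```python
class CumSumLayer:
    # Toy recurrent-style layer: the running sum is its hidden state,
    # so reset() must be called between independent sequences.
    is_rnn = True

    def __init__(self):
        self.hidden = 0

    def forward(self, x):
        self.hidden += x
        return self.hidden

    def reset(self):
        self.hidden = 0  # clear carried state

layer = CumSumLayer()
outs = [layer.forward(v) for v in (1, 2, 3)]  # [1, 3, 6]
layer.reset()  # next sequence starts from a clean hidden state
```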
- class rlify.models.rnn.GRU(hidden_dim=64, num_grus=2, *args, **kwargs)
Bases: ReccurentLayer
GRU model
- __init__(hidden_dim=64, num_grus=2, *args, **kwargs)
- Parameters:
hidden_dim (int) – the hidden dimension
num_grus (int) – the number of GRUs
*args – positional args passed to the base class
**kwargs – keyword args passed to the base class
- forward(x)
Forward pass of the model
- Parameters:
x (Tensor) – the input data, as a PackedSequence
- reset()
Resets the hidden state
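For a rough sense of model size, the parameter count of a stack of GRUs can be worked out from torch.nn.GRU's layout (per layer: 3-gate input-to-hidden and hidden-to-hidden weight matrices plus two bias vectors). Treat this as a sketch under that assumption; the GRU class here may add further layers on top.

```python
def gru_param_count(input_dim, hidden_dim, num_grus):
    # Per layer: weight_ih (3h x in), weight_hh (3h x h),
    # bias_ih and bias_hh (3h each). Layers after the first take
    # the previous layer's hidden state as input.
    total = 0
    for layer in range(num_grus):
        in_dim = input_dim if layer == 0 else hidden_dim
        total += (3 * hidden_dim * in_dim
                  + 3 * hidden_dim * hidden_dim
                  + 6 * hidden_dim)
    return total

# e.g. the defaults hidden_dim=64, num_grus=2 on 8-dim inputs:
# layer 0: 1536 + 12288 + 384; layer 1: 12288 + 12288 + 384 -> 39168
```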