diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..2c5c29d --- /dev/null +++ b/.env.example @@ -0,0 +1,12 @@ +CONN_URI=sqlite+aiosqlite:///data/db.sqlite3 +DEBUG=True +PROJECT_URL=http://127.0.0.1 +PROJECT_LOGIN_SUCCESS_URL=http://google.com +PROJECT_PORT=80 +JWT_PRIVATE_KEY='data/private_key' +JWT_PUBLIC_KEY='data/public_key' +BOT_TOKEN=xxx +BOT_USERNAME=xxxxBot +BOT_API_ID=111 +BOT_API_HASH=aaa +BOT_MANAGER_IDS=[111,222] \ No newline at end of file diff --git a/.gitignore b/.gitignore index 68bc17f..e08e503 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,6 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ + +data/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..c54138c --- /dev/null +++ b/README.md @@ -0,0 +1,32 @@ +# Telegram OAuth + +## Configuration + +```dotenv +CONN_URI=sqlite+aiosqlite:///data/db.sqlite3 # 数据库 uri +DEBUG=True # 调试模式 +PROJECT_URL=http://127.0.0.1 # 项目可访问的地址 +PROJECT_LOGIN_SUCCESS_URL=http://google.com # 登录成功后跳转的地址 +PROJECT_PORT=80 # 项目运行的端口 +JWT_PRIVATE_KEY='data/private_key' # jwt 私钥 +JWT_PUBLIC_KEY='data/public_key' # jwt 公钥 +BOT_TOKEN=xxx # 机器人 token +BOT_USERNAME=xxxxBot # 机器人用户名 +BOT_API_ID=111 # api id +BOT_API_HASH=aaa # api hash +BOT_MANAGER_IDS=[111,222] # 管理员 id +``` + +## OIDC Endpoints + +Auth URL : `/oauth2/authorize` + +Token URL : `/oauth2/token` + +Cert URL : `/oauth2/keys` + +## OIDC Client + +```sql +INSERT INTO "client" ("grant_types", "response_types", "redirect_uris", "id", "client_id", "client_secret", "scope") VALUES ('authorization_code', 'code', 'https://127.0.0.1/access/callback', 'UUID', '123456', '123456', 'openid profile email'); +``` diff --git a/aioauth_fastapi/__init__.py b/aioauth_fastapi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aioauth_fastapi/__version__.py b/aioauth_fastapi/__version__.py new file mode 100644 index 0000000..d5650b1 --- /dev/null +++ b/aioauth_fastapi/__version__.py @@ -0,0 +1,8 @@ +__title__ = "aioauth_fastapi" +__description__ = "aioauth integration for FastAPI." +__url__ = "https://github.com/aliev/aioauth-fastapi" +__version__ = "0.1.2" +__author__ = "Ali Aliyev" +__author_email__ = "ali@aliev.me" +__license__ = "The MIT License (MIT)" +__copyright__ = "Copyright 2021 Ali Aliyev" diff --git a/aioauth_fastapi/forms.py b/aioauth_fastapi/forms.py new file mode 100644 index 0000000..40a818b --- /dev/null +++ b/aioauth_fastapi/forms.py @@ -0,0 +1,38 @@ +""" +.. code-block:: python + + from aioauth_fastapi import forms + +FastAPI oauth2 forms. + +Used to generate an OpenAPI schema. + +---- +""" + +from dataclasses import dataclass +from typing import Optional + +from aioauth.types import GrantType, TokenType +from fastapi.params import Form + + +@dataclass +class TokenForm: + grant_type: Optional[GrantType] = Form(None) # type: ignore + client_id: Optional[str] = Form(None) # type: ignore + client_secret: Optional[str] = Form(None) # type: ignore + redirect_uri: Optional[str] = Form(None) # type: ignore + scope: Optional[str] = Form(None) # type: ignore + username: Optional[str] = Form(None) # type: ignore + password: Optional[str] = Form(None) # type: ignore + refresh_token: Optional[str] = Form(None) # type: ignore + code: Optional[str] = Form(None) # type: ignore + token: Optional[str] = Form(None) # type: ignore + code_verifier: Optional[str] = Form(None) # type: ignore + + +@dataclass +class TokenIntrospectForm: + token: Optional[str] = Form(None) # type: ignore + token_type_hint: Optional[TokenType] = Form(None) # type: ignore diff --git a/aioauth_fastapi/router.py b/aioauth_fastapi/router.py new file mode 100644 index 0000000..be3c3b0 --- /dev/null +++ b/aioauth_fastapi/router.py @@ -0,0 +1,113 @@ +""" +.. code-block:: python + + from aioauth_fastapi import router + +FastAPI routing of oauth2. + +Usage example + +.. code-block:: python + + from aioauth_fastapi.router import get_oauth2_router + from aioauth.storage import BaseStorage + from aioauth.config import Settings + from aioauth.server import AuthorizationServer + from fastapi import FastAPI + + app = FastAPI() + + class SQLAlchemyCRUD(BaseStorage): + ''' + SQLAlchemyCRUD methods must be implemented here. + ''' + + # NOTE: Redefinition of the default aioauth settings + # INSECURE_TRANSPORT must be enabled for local development only! + settings = Settings( + INSECURE_TRANSPORT=True, + ) + + storage = SQLAlchemyCRUD() + authorization_server = AuthorizationServer(storage) + + # Include FastAPI router with oauth2 endpoints. + app.include_router( + get_oauth2_router(authorization_server, settings), + prefix="/oauth2", + tags=["oauth2"], + ) + +---- +""" + +from typing import Callable, TypeVar + +from aioauth.config import Settings +from aioauth.requests import TRequest +from aioauth.server import AuthorizationServer +from aioauth.storage import TStorage +from fastapi import APIRouter, Request + +from .utils import ( + RequestArguments, + default_request_factory, + to_fastapi_response, + to_oauth2_request, +) + +ARequest = TypeVar("ARequest", bound=TRequest) + + +def get_oauth2_router( + authorization_server: AuthorizationServer[ARequest, TStorage], + settings: Settings = Settings(), + request_factory: Callable[[RequestArguments], ARequest] = default_request_factory, +) -> APIRouter: + """Function will create FastAPI router with the following oauth2 endpoints: + + * POST /token + * Endpoint creates a token response by :py:meth:`aioauth.server.AuthorizationServer.create_token_response` + * POST `/token/introspect` + * Endpoint creates a token introspection by :py:meth:`aioauth.server.AuthorizationServer.create_token_introspection_response` + * GET `/authorize` + * Endpoint creates an authorization response by :py:meth:`aioauth.server.AuthorizationServer.create_authorization_response` + + Returns: + :py:class:`fastapi.APIRouter`. + """ + router = APIRouter() + + @router.post("/token") + async def token(request: Request): + oauth2_request = await to_oauth2_request( + request=request, request_factory=request_factory, settings=settings + ) + oauth2_response = await authorization_server.create_token_response( + oauth2_request + ) + return await to_fastapi_response(oauth2_response) + + @router.post("/token/introspect") + async def token_introspect(request: Request): + oauth2_request = await to_oauth2_request( + request=request, request_factory=request_factory, settings=settings + ) + oauth2_response = ( + await authorization_server.create_token_introspection_response( + oauth2_request + ) + ) + return await to_fastapi_response(oauth2_response) + + @router.get("/authorize") + async def authorize(request: Request): + oauth2_request = await to_oauth2_request( + request=request, request_factory=request_factory, settings=settings + ) + oauth2_response = await authorization_server.create_authorization_response( + oauth2_request + ) + return await to_fastapi_response(oauth2_response) + + return router diff --git a/aioauth_fastapi/utils.py b/aioauth_fastapi/utils.py new file mode 100644 index 0000000..788932e --- /dev/null +++ b/aioauth_fastapi/utils.py @@ -0,0 +1,98 @@ +""" +.. code-block:: python + + from aioauth_fastapi import utils + +Core utils for integration with FastAPI + +---- +""" + +import json +from dataclasses import dataclass +from typing import Callable, Dict, Optional + +from aioauth.collections import HTTPHeaderDict +from aioauth.config import Settings +from aioauth.requests import Post, Query, TRequest, TUser +from aioauth.requests import Request as OAuth2Request +from aioauth.responses import Response as OAuth2Response +from fastapi import Request, Response + + +@dataclass +class RequestArguments: + headers: HTTPHeaderDict + method: str + post_args: Dict + query_args: Dict + settings: Settings + url: str + user: Optional[TUser] + + +def default_request_factory(request_args: RequestArguments) -> OAuth2Request: + return OAuth2Request( + headers=request_args.headers, + method=request_args.method, # type: ignore + post=Post(**request_args.post_args), # type: ignore + query=Query(**request_args.query_args), # type: ignore + settings=request_args.settings, + url=request_args.url, + user=request_args.user, + ) + + +async def to_oauth2_request( + request: Request, + settings: Settings = Settings(), + request_factory: Callable[[RequestArguments], TRequest] = default_request_factory, +) -> TRequest: + """Converts :py:class:`fastapi.Request` instance to :py:class:`aioauth.requests.Request` instance""" + form = await request.form() + + post_args = dict(form) + query_args = dict(request.query_params) + need_args = [ + "client_id", + "redirect_uri", + "response_type", + "state", + "scope", + "nonce", + "code_challenge_method", + "code_challenge", + "response_mode", + ] + for arg in list(query_args.keys()): + if arg not in need_args: + del query_args[arg] + method = request.method + headers = HTTPHeaderDict(**request.headers) + url = str(request.url) + + user = None + + if request.user.is_authenticated: + user = request.user + + request_args = RequestArguments( + headers=headers, + method=method, + post_args=post_args, + query_args=query_args, + settings=settings, + url=url, + user=user, + ) + return request_factory(request_args) + + +async def to_fastapi_response(oauth2_response: OAuth2Response) -> Response: + """Converts :py:class:`aioauth.responses.Response` instance to :py:class:`fastapi.Response` instance""" + response_content = oauth2_response.content + headers = dict(oauth2_response.headers) + status_code = oauth2_response.status_code + content = json.dumps(response_content) + + return Response(content=content, headers=headers, status_code=status_code) diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..436dba5 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,89 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic/ + +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date +# within the migration file as well as the filename. +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; this defaults +# to ./versions. When using multiple version +# directories, initial revisions must be specified with --version-path +# version_locations = %(here)s/bar %(here)s/bat ./versions + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = sqlite+aiosqlite:///data/db.sqlite3 + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..28bc9b5 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,90 @@ +import asyncio +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import engine_from_config, pool +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlmodel import SQLModel + +from src.config import settings +from src.oauth2.models import * # noqa +from src.users.models import * # noqa + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) + +config.set_main_option("sqlalchemy.url", str(settings.CONN_URI)) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = SQLModel.metadata # type: ignore + + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = AsyncEngine( + engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + future=True, + ) + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + asyncio.run(run_migrations_online()) diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..e45068f --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/07a7ace268a7_initial_migrations.py b/alembic/versions/07a7ace268a7_initial_migrations.py new file mode 100644 index 0000000..ab7550d --- /dev/null +++ b/alembic/versions/07a7ace268a7_initial_migrations.py @@ -0,0 +1,226 @@ +"""Initial migrations + +Revision ID: 07a7ace268a7 +Revises: +Create Date: 2021-10-02 22:50:10.418498 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "07a7ace268a7" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "users", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("is_superuser", sa.Boolean(), nullable=True), + sa.Column("is_blocked", sa.Boolean(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_users_id"), "users", ["id"], unique=True) + op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False) + op.create_index(op.f("ix_users_is_blocked"), "users", ["is_blocked"], unique=False) + op.create_index( + op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False + ) + op.create_index(op.f("ix_users_password"), "users", ["password"], unique=False) + op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True) + op.create_table( + "authorizationcode", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("code", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("redirect_uri", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("response_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("auth_time", sa.Integer(), nullable=False), + sa.Column("expires_in", sa.Integer(), nullable=False), + sa.Column("code_challenge", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column( + "code_challenge_method", sqlmodel.sql.sqltypes.AutoString(), nullable=True + ), + sa.Column("nonce", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_authorizationcode_auth_time"), + "authorizationcode", + ["auth_time"], + unique=False, + ) + op.create_index( + op.f("ix_authorizationcode_client_id"), + "authorizationcode", + ["client_id"], + unique=False, + ) + op.create_index( + op.f("ix_authorizationcode_code"), "authorizationcode", ["code"], unique=False + ) + op.create_index( + op.f("ix_authorizationcode_code_challenge"), + "authorizationcode", + ["code_challenge"], + unique=False, + ) + op.create_index( + op.f("ix_authorizationcode_code_challenge_method"), + "authorizationcode", + ["code_challenge_method"], + unique=False, + ) + op.create_index( + op.f("ix_authorizationcode_expires_in"), + "authorizationcode", + ["expires_in"], + unique=False, + ) + op.create_index( + op.f("ix_authorizationcode_id"), "authorizationcode", ["id"], unique=True + ) + op.create_index( + op.f("ix_authorizationcode_nonce"), "authorizationcode", ["nonce"], unique=False + ) + op.create_index( + op.f("ix_authorizationcode_redirect_uri"), + "authorizationcode", + ["redirect_uri"], + unique=False, + ) + op.create_index( + op.f("ix_authorizationcode_response_type"), + "authorizationcode", + ["response_type"], + unique=False, + ) + op.create_index( + op.f("ix_authorizationcode_scope"), "authorizationcode", ["scope"], unique=False + ) + op.create_index( + op.f("ix_authorizationcode_user_id"), + "authorizationcode", + ["user_id"], + unique=False, + ) + op.create_table( + "client", + sa.Column("grant_types", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("response_types", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("redirect_uris", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("client_secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_client_client_id"), "client", ["client_id"], unique=False) + op.create_index( + op.f("ix_client_client_secret"), "client", ["client_secret"], unique=False + ) + op.create_index(op.f("ix_client_id"), "client", ["id"], unique=True) + op.create_index(op.f("ix_client_scope"), "client", ["scope"], unique=False) + op.create_table( + "token", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("access_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("refresh_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("issued_at", sa.Integer(), nullable=False), + sa.Column("expires_in", sa.Integer(), nullable=False), + sa.Column("refresh_token_expires_in", sa.Integer(), nullable=False), + sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("token_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("revoked", sa.Boolean(), nullable=False), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_token_client_id"), "token", ["client_id"], unique=False) + op.create_index(op.f("ix_token_expires_in"), "token", ["expires_in"], unique=False) + op.create_index(op.f("ix_token_id"), "token", ["id"], unique=True) + op.create_index(op.f("ix_token_issued_at"), "token", ["issued_at"], unique=False) + op.create_index( + op.f("ix_token_refresh_token_expires_in"), + "token", + ["refresh_token_expires_in"], + unique=False, + ) + op.create_index(op.f("ix_token_revoked"), "token", ["revoked"], unique=False) + op.create_index(op.f("ix_token_scope"), "token", ["scope"], unique=False) + op.create_index(op.f("ix_token_token_type"), "token", ["token_type"], unique=False) + op.create_index(op.f("ix_token_user_id"), "token", ["user_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_token_user_id"), table_name="token") + op.drop_index(op.f("ix_token_token_type"), table_name="token") + op.drop_index(op.f("ix_token_scope"), table_name="token") + op.drop_index(op.f("ix_token_revoked"), table_name="token") + op.drop_index(op.f("ix_token_refresh_token_expires_in"), table_name="token") + op.drop_index(op.f("ix_token_issued_at"), table_name="token") + op.drop_index(op.f("ix_token_id"), table_name="token") + op.drop_index(op.f("ix_token_expires_in"), table_name="token") + op.drop_index(op.f("ix_token_client_id"), table_name="token") + op.drop_table("token") + op.drop_index(op.f("ix_client_scope"), table_name="client") + op.drop_index(op.f("ix_client_id"), table_name="client") + op.drop_index(op.f("ix_client_client_secret"), table_name="client") + op.drop_index(op.f("ix_client_client_id"), table_name="client") + op.drop_table("client") + op.drop_index(op.f("ix_authorizationcode_user_id"), table_name="authorizationcode") + op.drop_index(op.f("ix_authorizationcode_scope"), table_name="authorizationcode") + op.drop_index( + op.f("ix_authorizationcode_response_type"), table_name="authorizationcode" + ) + op.drop_index( + op.f("ix_authorizationcode_redirect_uri"), table_name="authorizationcode" + ) + op.drop_index(op.f("ix_authorizationcode_nonce"), table_name="authorizationcode") + op.drop_index(op.f("ix_authorizationcode_id"), table_name="authorizationcode") + op.drop_index( + op.f("ix_authorizationcode_expires_in"), table_name="authorizationcode" + ) + op.drop_index( + op.f("ix_authorizationcode_code_challenge_method"), + table_name="authorizationcode", + ) + op.drop_index( + op.f("ix_authorizationcode_code_challenge"), table_name="authorizationcode" + ) + op.drop_index(op.f("ix_authorizationcode_code"), table_name="authorizationcode") + op.drop_index( + op.f("ix_authorizationcode_client_id"), table_name="authorizationcode" + ) + op.drop_index( + op.f("ix_authorizationcode_auth_time"), table_name="authorizationcode" + ) + op.drop_table("authorizationcode") + op.drop_index(op.f("ix_users_username"), table_name="users") + op.drop_index(op.f("ix_users_password"), table_name="users") + op.drop_index(op.f("ix_users_is_superuser"), table_name="users") + op.drop_index(op.f("ix_users_is_blocked"), table_name="users") + op.drop_index(op.f("ix_users_is_active"), table_name="users") + op.drop_index(op.f("ix_users_id"), table_name="users") + op.drop_table("users") + # ### end Alembic commands ### diff --git a/alembic/versions/c76c4cbb0b3b_tgid.py b/alembic/versions/c76c4cbb0b3b_tgid.py new file mode 100644 index 0000000..2df3e5c --- /dev/null +++ b/alembic/versions/c76c4cbb0b3b_tgid.py @@ -0,0 +1,27 @@ +"""tgid + +Revision ID: c76c4cbb0b3b +Revises: 07a7ace268a7 +Create Date: 2024-01-13 16:12:45.884304 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'c76c4cbb0b3b' +down_revision = '07a7ace268a7' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('users', sa.Column('tg_id', sa.BigInteger(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('users', 'tg_id') + # ### end Alembic commands ### diff --git a/gen_keys.py b/gen_keys.py new file mode 100644 index 0000000..a24edb1 --- /dev/null +++ b/gen_keys.py @@ -0,0 +1,26 @@ +from pathlib import Path + +data_path = Path("data") +data_path.mkdir(exist_ok=True) +private_key_path = data_path / "private_key" +public_key_path = data_path / "public_key" + + +def gen_keys(): + from Crypto.PublicKey import RSA + + key = RSA.generate(2048) + private_key = key.export_key().decode("utf-8") + public_key = key.publickey().export_key().decode("utf-8") + + if private_key_path.is_file() and public_key_path.is_file(): + print("Keys already exist") + return + with open(private_key_path, "w") as f: + f.write(private_key) + with open(public_key_path, "w") as f: + f.write(public_key) + + +if __name__ == "__main__": + gen_keys() diff --git a/html/login.jinja b/html/login.jinja new file mode 100644 index 0000000..208dced --- /dev/null +++ b/html/login.jinja @@ -0,0 +1,40 @@ + + + + + Title + + + + +
+ +
+ + + +
+
+ + \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..112d8b5 --- /dev/null +++ b/main.py @@ -0,0 +1,50 @@ +import asyncio +from signal import signal as signal_fn, SIGINT, SIGTERM, SIGABRT + +from src.app import web +from src.bot import bot +from src.logs import logs + + +async def idle(): + task = None + + def signal_handler(_, __): + if web.web_server_task: + web.web_server_task.cancel() + task.cancel() + + for s in (SIGINT, SIGTERM, SIGABRT): + signal_fn(s, signal_handler) + + while True: + task = asyncio.create_task(asyncio.sleep(600)) + web.bot_main_task = task + try: + await task + except asyncio.CancelledError: + break + + +async def main(): + logs.info("正在启动 Web Server") + await web.start() + logs.info("正在启动 Bot") + await bot.start() + try: + logs.info("正在运行") + await idle() + finally: + try: + await bot.stop() + except ConnectionError: + pass + if web.web_server: + try: + await web.web_server.shutdown() + except AttributeError: + pass + + +if __name__ == "__main__": + bot.run(main()) diff --git a/pyromod/__init__.py b/pyromod/__init__.py new file mode 100644 index 0000000..c22821d --- /dev/null +++ b/pyromod/__init__.py @@ -0,0 +1,21 @@ +""" +pyromod - A monkeypatched add-on for Pyrogram +Copyright (C) 2020 Cezar H. + +This file is part of pyromod. + +pyromod is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +pyromod is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with pyromod. If not, see . +""" + +__version__ = "1.5" diff --git a/pyromod/listen/__init__.py b/pyromod/listen/__init__.py new file mode 100644 index 0000000..b916072 --- /dev/null +++ b/pyromod/listen/__init__.py @@ -0,0 +1,21 @@ +""" +pyromod - A monkeypatcher add-on for Pyrogram +Copyright (C) 2020 Cezar H. + +This file is part of pyromod. + +pyromod is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +pyromod is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with pyromod. If not, see . +""" + +from .listen import Client, MessageHandler, Chat, User diff --git a/pyromod/listen/listen.py b/pyromod/listen/listen.py new file mode 100644 index 0000000..a443885 --- /dev/null +++ b/pyromod/listen/listen.py @@ -0,0 +1,157 @@ +""" +pyromod - A monkeypatcher add-on for Pyrogram +Copyright (C) 2020 Cezar H. + +This file is part of pyromod. + +pyromod is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +pyromod is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with pyromod. If not, see . +""" + +import asyncio +import functools + +import pyrogram + +from src.scheduler import add_delete_message_job +from ..utils import patch, patchable +from ..utils.errors import ListenerCanceled, TimeoutConversationError + +pyrogram.errors.ListenerCanceled = ListenerCanceled + + +@patch(pyrogram.client.Client) +class Client: + @patchable + def __init__(self, *args, **kwargs): + self.listening = {} + self.using_mod = True + + self.old__init__(*args, **kwargs) + + @patchable + async def listen(self, chat_id, filters=None, timeout=None): + if type(chat_id) != int: + chat = await self.get_chat(chat_id) + chat_id = chat.id + + future = self.loop.create_future() + future.add_done_callback(functools.partial(self.clear_listener, chat_id)) + self.listening.update({chat_id: {"future": future, "filters": filters}}) + try: + return await asyncio.wait_for(future, timeout) + except asyncio.exceptions.TimeoutError as e: + raise TimeoutConversationError() from e + + @patchable + async def ask(self, chat_id, text, filters=None, timeout=None, *args, **kwargs): + request = await self.send_message(chat_id, text, *args, **kwargs) + response = await self.listen(chat_id, filters, timeout) + response.request = request + return response + + @patchable + def clear_listener(self, chat_id, future): + if future == self.listening[chat_id]["future"]: + self.listening.pop(chat_id, None) + + @patchable + def cancel_listener(self, chat_id): + listener = self.listening.get(chat_id) + if not listener or listener["future"].done(): + return + + listener["future"].set_exception(ListenerCanceled()) + self.clear_listener(chat_id, listener["future"]) + + @patchable + def cancel_all_listener(self): + for chat_id in self.listening: + self.cancel_listener(chat_id) + + +@patch(pyrogram.handlers.message_handler.MessageHandler) +class MessageHandler: + @patchable + def __init__(self, callback: callable, filters=None): + self.user_callback = callback + self.old__init__(self.resolve_listener, filters) + + @patchable + async def resolve_listener(self, client, message, *args): + listener = client.listening.get(message.chat.id) + if listener and not listener["future"].done(): + listener["future"].set_result(message) + else: + if listener and listener["future"].done(): + client.clear_listener(message.chat.id, listener["future"]) + await self.user_callback(client, message, *args) + + @patchable + async def check(self, client, update): + listener = client.listening.get(update.chat.id) + + if listener and not listener["future"].done(): + return ( + await listener["filters"](client, update) + if callable(listener["filters"]) + else True + ) + + return await self.filters(client, update) if callable(self.filters) else True + + +@patch(pyrogram.types.user_and_chats.chat.Chat) +class Chat(pyrogram.types.Chat): + @patchable + def listen(self, *args, **kwargs): + return self._client.listen(self.id, *args, **kwargs) + + @patchable + def ask(self, *args, **kwargs): + return self._client.ask(self.id, *args, **kwargs) + + @patchable + def cancel_listener(self): + return self._client.cancel_listener(self.id) + + +@patch(pyrogram.types.user_and_chats.user.User) +class User(pyrogram.types.User): + @patchable + def listen(self, *args, **kwargs): + return self._client.listen(self.id, *args, **kwargs) + + @patchable + def ask(self, *args, **kwargs): + return self._client.ask(self.id, *args, **kwargs) + + @patchable + def cancel_listener(self): + return self._client.cancel_listener(self.id) + + +@patch(pyrogram.types.messages_and_media.Message) +class Message(pyrogram.types.Message): + @patchable + async def safe_delete(self, revoke: bool = True): + try: + return await self._client.delete_messages( + chat_id=self.chat.id, message_ids=self.id, revoke=revoke + ) + except Exception as e: # noqa + return False + + @patchable + async def delay_delete(self, delay: int = 60): + add_delete_message_job(self, delay) diff --git a/pyromod/utils/__init__.py b/pyromod/utils/__init__.py new file mode 100644 index 0000000..3c9f81a --- /dev/null +++ b/pyromod/utils/__init__.py @@ -0,0 +1,21 @@ +""" +pyromod - A monkeypatcher add-on for Pyrogram +Copyright (C) 2020 Cezar H. + +This file is part of pyromod. + +pyromod is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +pyromod is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with pyromod. If not, see . +""" + +from .utils import patch, patchable diff --git a/pyromod/utils/errors.py b/pyromod/utils/errors.py new file mode 100644 index 0000000..d72ab10 --- /dev/null +++ b/pyromod/utils/errors.py @@ -0,0 +1,16 @@ +class TimeoutConversationError(Exception): + """ + Occurs when the conversation times out. + """ + + def __init__(self): + super().__init__("Response read timed out") + + +class ListenerCanceled(Exception): + """ + Occurs when the listener is canceled. + """ + + def __init__(self): + super().__init__("Listener was canceled") diff --git a/pyromod/utils/utils.py b/pyromod/utils/utils.py new file mode 100644 index 0000000..2202260 --- /dev/null +++ b/pyromod/utils/utils.py @@ -0,0 +1,38 @@ +""" +pyromod - A monkeypatcher add-on for Pyrogram +Copyright (C) 2020 Cezar H. + +This file is part of pyromod. + +pyromod is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +pyromod is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with pyromod. If not, see . +""" + + +def patch(obj): + def is_patchable(item): + return getattr(item[1], "patchable", False) + + def wrapper(container): + for name, func in filter(is_patchable, container.__dict__.items()): + old = getattr(obj, name, None) + setattr(obj, f"old{name}", old) + setattr(obj, name, func) + return container + + return wrapper + + +def patchable(func): + func.patchable = True + return func diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4324cfb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +fastapi==0.109.0 +uvicorn==0.25.0 +git+https://github.com/aliev/aioauth +sqlmodel==0.0.14 +alembic==1.13.1 +aiosqlite==0.19.0 +PyCryptodome==3.20.0 +python-jose[cryptography]==3.3.0 +python-multipart==0.0.6 +orjson==3.9.10 +jinja2==3.1.3 +pydantic~=2.5.3 +pydantic-settings==2.1.0 +SQLAlchemy~=2.0.25 +starlette~=0.35.1 +pyrogram==2.0.106 +tgcrypto==1.2.5 +pytz~=2023.3.post1 +APScheduler~=3.10.4 +coloredlogs~=15.0.1 +httpx==0.26.0 +asyncmy==0.2.9 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app.py b/src/app.py new file mode 100644 index 0000000..b13e50d --- /dev/null +++ b/src/app.py @@ -0,0 +1,87 @@ +import asyncio + +from fastapi import FastAPI +from fastapi.responses import ORJSONResponse +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.middleware.cors import CORSMiddleware + +from .config import settings +from .events import on_shutdown, on_startup +from .logs import logs +from .oauth2 import endpoints as oauth2_endpoints +from .users import endpoints as users_endpoints +from .users.backends import TokenAuthenticationBackend + + +class Web: + def __init__(self): + self.app = FastAPI( + title=settings.PROJECT_NAME, + docs_url="/api/openapi", + openapi_url="/api/openapi.json", + default_response_class=ORJSONResponse, + on_startup=on_startup, + on_shutdown=on_shutdown, + ) + self.web_server = None + self.web_server_task = None + self.bot_main_task = None + + def init_web(self): + # Include API router + self.app.include_router(users_endpoints.router, prefix="/api/users", tags=["users"]) + + # Define aioauth-fastapi endpoints + self.app.include_router( + oauth2_endpoints.router, + prefix="/oauth2", + tags=["oauth2"], + ) + self.app.add_middleware( + CORSMiddleware, + allow_origins=settings.CORS_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + self.app.add_middleware(AuthenticationMiddleware, backend=TokenAuthenticationBackend()) + + async def start(self): + import uvicorn + + self.init_web() + self.web_server = uvicorn.Server( + config=uvicorn.Config( + self.app, + host=settings.PROJECT_HOST, + port=settings.PROJECT_PORT, + reload=settings.DEBUG, + ) + ) + server_config = self.web_server.config + server_config.setup_event_loop() + if not server_config.loaded: + server_config.load() + self.web_server.lifespan = server_config.lifespan_class(server_config) + try: + await self.web_server.startup() + except OSError as e: + if e.errno == 10048: + logs.error("Web Server 端口被占用:%s", e) + logs.error("Web Server 启动失败,正在退出") + raise SystemExit from None + + if self.web_server.should_exit: + logs.error("Web Server 启动失败,正在退出") + raise SystemExit from None + logs.info("Web Server 启动成功") + self.web_server_task = asyncio.create_task(self.web_server.main_loop()) + + async def stop(self): + if self.web_server_task: + self.web_server_task.cancel() + if self.bot_main_task: + self.bot_main_task.cancel() + + +web = Web() diff --git a/src/bot.py b/src/bot.py new file mode 100644 index 0000000..3c46665 --- /dev/null +++ b/src/bot.py @@ -0,0 +1,13 @@ +import pyromod.listen +from pyrogram import Client + +from .config import settings + +bot = Client( + "bot", + bot_token=settings.BOT_TOKEN, + api_id=settings.BOT_API_ID, + api_hash=settings.BOT_API_HASH, + plugins={"root": "src.telegram.plugins"}, + workdir="data", +) diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..2b50888 --- /dev/null +++ b/src/config.py @@ -0,0 +1,35 @@ +from typing import List + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + PROJECT_NAME: str = "Telegram OAuth" + PROJECT_URL: str = "http://127.0.0.1:8081" + PROJECT_LOGIN_SUCCESS_URL: str = "http://google.com" + PROJECT_HOST: str = "127.0.0.1" + PROJECT_PORT: int = 8001 + DEBUG: bool = True + + CONN_URI: str + + JWT_PUBLIC_KEY: str + JWT_PRIVATE_KEY: str + + ACCESS_TOKEN_EXP: int = 900 # 15 minutes + REFRESH_TOKEN_EXP: int = 86400 # 1 day + + CORS_ORIGINS: List[str] = ["*"] + + BOT_TOKEN: str + BOT_USERNAME: str + BOT_API_ID: int + BOT_API_HASH: str + BOT_MANAGER_IDS: List[int] + + class Config: + env_file = ".env" + case_sensitive = True + + +settings = Settings() diff --git a/src/events.py b/src/events.py new file mode 100644 index 0000000..d88c603 --- /dev/null +++ b/src/events.py @@ -0,0 +1,27 @@ +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool +from sqlmodel.ext.asyncio.session import AsyncSession + +from .config import settings +from .storage import sqlalchemy + + +async def create_sqlalchemy_connection(): + # NOTE: https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html#using-multiple-asyncio-event-loops + engine = create_async_engine(settings.CONN_URI, echo=True, poolclass=NullPool) + async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + sqlalchemy.sqlalchemy_session = async_session() + + +async def close_sqlalchemy_connection(): + if sqlalchemy.sqlalchemy_session is not None: + await sqlalchemy.sqlalchemy_session.close() + + +on_startup = [ + create_sqlalchemy_connection, +] +on_shutdown = [ + close_sqlalchemy_connection, +] diff --git a/src/html.py b/src/html.py new file mode 100644 index 0000000..1c1448a --- /dev/null +++ b/src/html.py @@ -0,0 +1,3 @@ +from starlette.templating import Jinja2Templates + +templates = Jinja2Templates(directory="html") diff --git a/src/logs.py b/src/logs.py new file mode 100644 index 0000000..61a3d8f --- /dev/null +++ b/src/logs.py @@ -0,0 +1,21 @@ +from logging import getLogger, StreamHandler, basicConfig, INFO, CRITICAL, ERROR + +from coloredlogs import ColoredFormatter + +logs = getLogger("telegram-oauth") +logging_format = "%(levelname)s [%(asctime)s] [%(name)s] %(message)s" +logging_handler = StreamHandler() +logging_handler.setFormatter(ColoredFormatter(logging_format)) +root_logger = getLogger() +root_logger.setLevel(CRITICAL) +root_logger.addHandler(logging_handler) +pyro_logger = getLogger("pyrogram") +pyro_logger.setLevel(INFO) +sql_logger = getLogger("sqlalchemy") +sql_logger.setLevel(CRITICAL) +sql_engine_logger = getLogger("sqlalchemy.engine.Engine") +sql_engine_logger.setLevel(CRITICAL) +aioauth_logger = getLogger("aioauth") +aioauth_logger.setLevel(INFO) +basicConfig(level=ERROR) +logs.setLevel(INFO) diff --git a/src/oauth2/__init__.py b/src/oauth2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/oauth2/endpoints.py b/src/oauth2/endpoints.py new file mode 100644 index 0000000..4acb3ce --- /dev/null +++ b/src/oauth2/endpoints.py @@ -0,0 +1,56 @@ +from aioauth.config import Settings +from aioauth.oidc.core.grant_type import AuthorizationCodeGrantType +from aioauth.requests import Request as OAuth2Request +from aioauth.server import AuthorizationServer +from fastapi import APIRouter, Depends, Request + +from aioauth_fastapi.utils import to_fastapi_response, to_oauth2_request +from .storage import Storage +from ..config import settings as local_settings +from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage +from ..users.crypto import get_pub_key_resp +from ..utils.oauth import to_login_request + +router = APIRouter() + +settings = Settings( + TOKEN_EXPIRES_IN=local_settings.ACCESS_TOKEN_EXP, + REFRESH_TOKEN_EXPIRES_IN=local_settings.REFRESH_TOKEN_EXP, + INSECURE_TRANSPORT=local_settings.DEBUG, +) +grant_types = { + "authorization_code": AuthorizationCodeGrantType, +} + + +@router.post("/token") +async def token( + request: Request, + storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage), +): + oauth2_storage = Storage(storage=storage) + authorization_server = AuthorizationServer(storage=oauth2_storage, grant_types=grant_types) + oauth2_request: OAuth2Request = await to_oauth2_request(request, settings) + oauth2_response = await authorization_server.create_token_response(oauth2_request) + return await to_fastapi_response(oauth2_response) + + +@router.get("/authorize") +async def authorize( + request: Request, + storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage), +): + if not request.user.is_authenticated: + return await to_login_request(request) + oauth2_storage = Storage(storage=storage) + authorization_server = AuthorizationServer(storage=oauth2_storage, grant_types=grant_types) + oauth2_request: OAuth2Request = await to_oauth2_request(request, settings) + oauth2_response = await authorization_server.create_authorization_response( + oauth2_request + ) + return await to_fastapi_response(oauth2_response) + + +@router.get("/keys") +async def keys(): + return get_pub_key_resp() diff --git a/src/oauth2/models.py b/src/oauth2/models.py new file mode 100644 index 0000000..4a16aad --- /dev/null +++ b/src/oauth2/models.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING, Optional + +from pydantic.types import UUID4 +from sqlmodel.main import Field, Relationship + +from ..storage.models import BaseTable + +if TYPE_CHECKING: # pragma: no cover + from ..users.models import User + + +class Client(BaseTable, table=True): # type: ignore + client_id: str + client_secret: str + grant_types: str + response_types: str + redirect_uris: str + + scope: str + + +class AuthorizationCode(BaseTable, table=True): # type: ignore + code: str + client_id: str + redirect_uri: str + response_type: str + scope: str + auth_time: int + expires_in: int + code_challenge: Optional[str] + code_challenge_method: Optional[str] + nonce: Optional[str] + + user_id: UUID4 = Field(foreign_key="users.id", nullable=False) + user: "User" = Relationship(back_populates="user_authorization_codes") + + +class Token(BaseTable, table=True): # type: ignore + access_token: str + refresh_token: str + scope: str + issued_at: int + expires_in: int + refresh_token_expires_in: int + client_id: str + token_type: str + revoked: bool + + user_id: UUID4 = Field(foreign_key="users.id", nullable=False) + user: "User" = Relationship(back_populates="user_tokens") diff --git a/src/oauth2/storage.py b/src/oauth2/storage.py new file mode 100644 index 0000000..d008a2a --- /dev/null +++ b/src/oauth2/storage.py @@ -0,0 +1,289 @@ +from datetime import datetime, timezone +from typing import Optional + +from aioauth.models import AuthorizationCode, Client, Token +from aioauth.requests import Request +from aioauth.storage import BaseStorage +from aioauth.types import CodeChallengeMethod, ResponseType, TokenType +from aioauth.utils import enforce_list +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload +from sqlalchemy.sql.expression import delete + +from .models import AuthorizationCode as AuthorizationCodeDB +from .models import Client as ClientDB +from .models import Token as TokenDB +from ..config import settings +from ..storage.sqlalchemy import SQLAlchemyStorage +from ..users.crypto import encode_jwt, get_jwt +from ..users.models import User + + +class Storage(BaseStorage): + def __init__(self, storage: SQLAlchemyStorage): + self.storage = storage + + async def get_user(self, request: Request): + user: Optional[User] = None + + if request.query.response_type == "token": + # If ResponseType is token get the user from current session + user = request.user + + if request.post.grant_type == "authorization_code": + # If GrantType is authorization code get user from DB by code + q_results = await self.storage.select( + select(AuthorizationCodeDB).where( + AuthorizationCodeDB.code == request.post.code + ) + ) + + authorization_code: Optional[AuthorizationCodeDB] + authorization_code = q_results.scalars().one_or_none() + + if not authorization_code: + return + + q_results = await self.storage.select( + select(User).where(User.id == authorization_code.user_id) + ) + + user = q_results.scalars().one_or_none() + + if request.post.grant_type == "refresh_token": + # Get user from token + q_results = await self.storage.select( + select(TokenDB) + .where(TokenDB.refresh_token == request.post.refresh_token) + .options(selectinload(TokenDB.user)) + ) + + token: Optional[TokenDB] + + token = q_results.scalars().one_or_none() + + if not token: + return + + user = token.user + + return user + + async def create_token( + self, + request: Request, + client_id: str, + scope: str, + access_token: str, + refresh_token: str, + ) -> Token: + """ + Create token and store it in storage. + """ + user = await self.get_user(request) + + _access_token, _refresh_token = get_jwt(user) + + token = Token( + access_token=_access_token, + client_id=client_id, + expires_in=300, + issued_at=int(datetime.now(tz=timezone.utc).timestamp()), + refresh_token=_refresh_token, + refresh_token_expires_in=900, + revoked=False, + scope=scope, + token_type="Bearer", + user=user, + ) + + token_record = TokenDB( + access_token=token.access_token, + refresh_token=token.refresh_token, + scope=token.scope, + issued_at=token.issued_at, + expires_in=token.expires_in, + refresh_token_expires_in=token.refresh_token_expires_in, + client_id=token.client_id, + token_type=token.token_type, + revoked=token.revoked, + user_id=user.id, + ) + + await self.storage.add(token_record) + + return token + + async def revoke_token( + self, + request: Request, + token_type: Optional[TokenType] = "refresh_token", + access_token: Optional[str] = None, + refresh_token: Optional[str] = None, + ) -> None: + """ + Remove refresh_token from whitelist. + """ + q_results = await self.storage.select( + select(TokenDB).where(TokenDB.refresh_token == refresh_token) + ) + token_record: Optional[TokenDB] + token_record = q_results.scalars().one_or_none() + + if token_record: + token_record.revoked = True + await self.storage.add(token_record) + + async def get_token( + self, + request: Request, + client_id: str, + token_type: Optional[str] = "refresh_token", + access_token: Optional[str] = None, + refresh_token: Optional[str] = None, + ) -> Optional[Token]: + if token_type == "refresh_token": + q = select(TokenDB).where(TokenDB.refresh_token == refresh_token) + else: + q = select(TokenDB).where(TokenDB.access_token == access_token) + + q_results = await self.storage.select( + q.where(TokenDB.revoked == False).options( # noqa + selectinload(TokenDB.user) + ) + ) + + token_record: Optional[TokenDB] + token_record = q_results.scalars().one_or_none() + + if token_record: + return Token( + access_token=token_record.access_token, + refresh_token=token_record.refresh_token, + scope=token_record.scope, + issued_at=token_record.issued_at, + expires_in=token_record.expires_in, + refresh_token_expires_in=token_record.refresh_token_expires_in, + client_id=client_id, + ) + + async def create_authorization_code( + self, + request: Request, + client_id: str, + scope: str, + response_type: ResponseType, + redirect_uri: str, + code_challenge_method: Optional[CodeChallengeMethod], + code_challenge: Optional[str], + code: str, + **kwargs, + ) -> AuthorizationCode: + authorization_code = AuthorizationCode( + auth_time=int(datetime.now(tz=timezone.utc).timestamp()), + client_id=client_id, + code=code, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + expires_in=300, + redirect_uri=redirect_uri, + response_type=response_type, + scope=scope, + user=request.user, + ) + + authorization_code_record = AuthorizationCodeDB( + code=authorization_code.code, + client_id=authorization_code.client_id, + redirect_uri=authorization_code.redirect_uri, + response_type=authorization_code.response_type, + scope=authorization_code.scope, + auth_time=authorization_code.auth_time, + expires_in=authorization_code.expires_in, + code_challenge_method=authorization_code.code_challenge_method, + code_challenge=authorization_code.code_challenge, + nonce=authorization_code.nonce, + user_id=request.user.id, + ) + + await self.storage.add(authorization_code_record) + + return authorization_code + + async def get_client( + self, request: Request, client_id: str, client_secret: Optional[str] = None + ) -> Optional[Client]: + q_results = await self.storage.select( + select(ClientDB).where(ClientDB.client_id == client_id) + ) + + client_record: Optional[ClientDB] + client_record = q_results.scalars().one_or_none() + + if not client_record: + return None + + return Client( + client_id=client_record.client_id, + client_secret=client_record.client_secret, + grant_types=[client_record.grant_types], + response_types=[client_record.response_types], + redirect_uris=[client_record.redirect_uris], + scope=client_record.scope, + ) + + async def get_authorization_code( + self, request: Request, client_id: str, code: str + ) -> Optional[AuthorizationCode]: + q_results = await self.storage.select( + select(AuthorizationCodeDB).where(AuthorizationCodeDB.code == code) + ) + + authorization_code_record: Optional[AuthorizationCode] + authorization_code_record = q_results.scalars().one_or_none() + + if not authorization_code_record: + return None + + return AuthorizationCode( + code=authorization_code_record.code, + client_id=authorization_code_record.client_id, + redirect_uri=authorization_code_record.redirect_uri, + response_type=authorization_code_record.response_type, + scope=authorization_code_record.scope, + auth_time=authorization_code_record.auth_time, + expires_in=authorization_code_record.expires_in, + code_challenge=authorization_code_record.code_challenge, + code_challenge_method=authorization_code_record.code_challenge_method, + nonce=authorization_code_record.nonce, + ) + + async def delete_authorization_code( + self, request: Request, client_id: str, code: str + ) -> None: + await self.storage.delete( + delete(AuthorizationCodeDB).where(AuthorizationCodeDB.code == code) + ) + + async def get_id_token( + self, + request: Request, + client_id: str, + scope: str, + response_type: ResponseType, + redirect_uri: str, + **kwargs, + ) -> str: + scopes = enforce_list(scope) + user = await self.get_user(request) + user_data = {} + + if "email" in scopes: + user_data["email"] = user.username + user_data["username"] = user.username + + return encode_jwt( + expires_delta=settings.ACCESS_TOKEN_EXP, + sub=str(user.id), + additional_claims=user_data, + ) diff --git a/src/scheduler.py b/src/scheduler.py new file mode 100644 index 0000000..ad08020 --- /dev/null +++ b/src/scheduler.py @@ -0,0 +1,33 @@ +import contextlib +import datetime +from typing import TYPE_CHECKING + +import pytz +from apscheduler.schedulers.asyncio import AsyncIOScheduler + +if TYPE_CHECKING: + from src.telegram.enums import Message + +scheduler = AsyncIOScheduler(timezone="Asia/ShangHai") +if not scheduler.running: + scheduler.start() + + +async def delete_message(message: "Message") -> bool: + with contextlib.suppress(Exception): + await message.delete() + return True + return False + + +def add_delete_message_job(message: "Message", delete_seconds: int = 60): + scheduler.add_job( + delete_message, + "date", + id=f"{message.chat.id}|{message.id}|delete_message", + name=f"{message.chat.id}|{message.id}|delete_message", + args=[message], + run_date=datetime.datetime.now(pytz.timezone("Asia/Shanghai")) + + datetime.timedelta(seconds=delete_seconds), + replace_existing=True, + ) diff --git a/src/storage/__init__.py b/src/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/storage/models.py b/src/storage/models.py new file mode 100644 index 0000000..62108e5 --- /dev/null +++ b/src/storage/models.py @@ -0,0 +1,14 @@ +import uuid + +from pydantic.types import UUID4 +from sqlmodel import Field, SQLModel + + +class BaseTable(SQLModel): + id: UUID4 = Field( + primary_key=True, + default_factory=uuid.uuid4, + nullable=False, + index=True, + sa_column_kwargs={"unique": True}, + ) diff --git a/src/storage/sqlalchemy.py b/src/storage/sqlalchemy.py new file mode 100644 index 0000000..205405c --- /dev/null +++ b/src/storage/sqlalchemy.py @@ -0,0 +1,70 @@ +from typing import Optional + +from sqlalchemy.engine.result import Result +from sqlalchemy.sql.expression import Delete, Update +from sqlalchemy.sql.selectable import Select +from sqlmodel.ext.asyncio.session import AsyncSession + + +class SQLAlchemyTransaction: + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def __aenter__(self) -> "SQLAlchemyTransaction": + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + if exc_type is None: + await self.commit() + else: + await self.rollback() + + await self.close() + + async def rollback(self): + await self.session.rollback() + + async def commit(self): + await self.session.commit() + + async def close(self): + await self.session.close() + + +class SQLAlchemyStorage: + def __init__( + self, session: AsyncSession, transaction: SQLAlchemyTransaction + ) -> None: + self.session = session + self.transaction = transaction + + async def select(self, q: Select) -> Result: + async with self.transaction: + return await self.session.execute(q) + + async def add(self, model) -> None: + async with self.transaction: + self.session.add(model) + + async def delete(self, q: Delete) -> None: + async with self.transaction: + await self.session.execute(q) + + async def update(self, q: Update): + async with self.transaction: + await self.session.execute(q) + + +sqlalchemy_session: Optional[AsyncSession] = None + + +def get_sqlalchemy_storage() -> SQLAlchemyStorage: + """Get SQLAlchemy storage instance. + + Returns: + SQLAlchemyStorage: SQLAlchemy storage instance + """ + sqllachemy_trancation = SQLAlchemyTransaction(session=sqlalchemy_session) + return SQLAlchemyStorage( + session=sqlalchemy_session, transaction=sqllachemy_trancation + ) diff --git a/src/telegram/__init__.py b/src/telegram/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/telegram/enums.py b/src/telegram/enums.py new file mode 100644 index 0000000..2cfe36c --- /dev/null +++ b/src/telegram/enums.py @@ -0,0 +1,25 @@ +from typing import Optional + +from pyrogram import Client as PyroClient +from pyrogram.types import Message as PyroMessage + + +class Client(PyroClient): # noqa + async def listen(self, chat_id, filters=None, timeout=None) -> Optional["Message"]: + return + + async def ask( + self, chat_id, text, filters=None, timeout=None, *args, **kwargs + ) -> Optional["Message"]: + return + + def cancel_listener(self, chat_id): + """Cancel the conversation with the given chat_id.""" + + +class Message(PyroMessage): # noqa + async def delay_delete(self, delete_seconds: int = 60) -> Optional[bool]: + return + + async def safe_delete(self, revoke: bool = True) -> None: + return diff --git a/src/telegram/filters.py b/src/telegram/filters.py new file mode 100644 index 0000000..aa89b8f --- /dev/null +++ b/src/telegram/filters.py @@ -0,0 +1,15 @@ +from typing import TYPE_CHECKING + +from pyrogram.filters import create + +from src.config import settings + +if TYPE_CHECKING: + from .enums import Message + + +async def admin_filter(_, __, m: "Message"): + return bool(m.from_user and m.from_user.id in settings.BOT_MANAGER_IDS) + + +admin = create(admin_filter) diff --git a/src/telegram/message.py b/src/telegram/message.py new file mode 100644 index 0000000..2b2952c --- /dev/null +++ b/src/telegram/message.py @@ -0,0 +1,8 @@ +import re + +NO_ACCOUNT_MSG = """UID `%s` 还没有注册账号,请联系管理员注册账号。""" +ACCOUNT_MSG = """UID: `%s`\n邮箱: `%s`""" +REG_MSG = """请发送需要使用的邮箱""" +MAIL_REGEX = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$") +LOGIN_MSG = """请点击下面的按钮登录:""" +LOGIN_BUTTON = """跳转登录""" diff --git a/src/telegram/plugins/__init__.py b/src/telegram/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/telegram/plugins/account.py b/src/telegram/plugins/account.py new file mode 100644 index 0000000..24e1645 --- /dev/null +++ b/src/telegram/plugins/account.py @@ -0,0 +1,24 @@ +from pyrogram import filters + +from src.bot import bot +from src.config import settings +from src.telegram.enums import Client, Message +from src.telegram.message import ACCOUNT_MSG, NO_ACCOUNT_MSG +from src.users.crud import get_user_crud + + +async def account(message: Message, uid: int): + crud = get_user_crud() + user = await crud.get_by_tg_id(uid) + if user: + await message.reply(ACCOUNT_MSG % (user.tg_id, user.username), quote=True) + else: + await message.reply(NO_ACCOUNT_MSG % uid, quote=True) + + +@bot.on_message(filters=filters.private & filters.command("account")) +async def get_account(_: Client, message: Message): + uid = message.from_user.id + if uid in settings.BOT_MANAGER_IDS and len(message.command) >= 2 and message.command[1].isnumeric(): + uid = int(message.command[1]) + await account(message, uid) diff --git a/src/telegram/plugins/edit.py b/src/telegram/plugins/edit.py new file mode 100644 index 0000000..acc4e34 --- /dev/null +++ b/src/telegram/plugins/edit.py @@ -0,0 +1,44 @@ +from pyrogram import filters + +from pyromod.utils.errors import TimeoutConversationError +from src.bot import bot +from src.logs import logs +from src.telegram.enums import Client, Message +from src.telegram.filters import admin +from src.telegram.message import REG_MSG, MAIL_REGEX +from src.users.crud import get_user_crud + + +async def reg(client: Client, from_id: int, uid: int): + msg_ = await client.send_message(from_id, REG_MSG) + try: + msg = await client.listen(from_id, filters=filters.text, timeout=60) + except TimeoutConversationError: + await msg_.edit("响应超时,请重试") + return + if msg.text and MAIL_REGEX.match(msg.text): + crud = get_user_crud() + try: + user = await crud.get_by_tg_id(uid) + if user: + await crud.update(user, username=msg.text) + else: + await crud.create( + username=msg.text, + password="1", + tg_id=uid, + ) + except Exception as e: + logs.exception("注册失败", exc_info=e) + await msg.reply_text("注册失败") + await msg.reply_text("注册成功") + else: + await msg.reply_text("邮箱格式错误") + + +@bot.on_message(filters=filters.private & filters.command("edit") & admin) +async def edit_account(client: Client, message: Message): + uid = from_id = message.from_user.id + if len(message.command) >= 2 and message.command[1].isnumeric(): + uid = int(message.command[1]) + await reg(client, from_id, uid) diff --git a/src/telegram/plugins/start.py b/src/telegram/plugins/start.py new file mode 100644 index 0000000..090aea7 --- /dev/null +++ b/src/telegram/plugins/start.py @@ -0,0 +1,41 @@ +from httpx import URL +from pyrogram import filters, Client +from pyrogram.types import InlineKeyboardMarkup, InlineKeyboardButton + +from src.bot import bot +from src.config import settings +from src.telegram.enums import Message +from src.telegram.message import NO_ACCOUNT_MSG, LOGIN_MSG, LOGIN_BUTTON +from src.users.crud import get_user_crud +from src.utils.telegram import encode_telegram_auth_data + + +async def login(message: Message): + uid = message.from_user.id + crud = get_user_crud() + user = await crud.get_by_tg_id(uid) + if not user: + await message.reply(NO_ACCOUNT_MSG % uid, quote=True) + return + token = await encode_telegram_auth_data(uid) + url = settings.PROJECT_URL + "/api/users/auth" + url = URL(url).copy_add_param("jwt", token) + url = str(url) + await message.reply( + LOGIN_MSG, + quote=True, + reply_markup=InlineKeyboardMarkup( + [[InlineKeyboardButton(LOGIN_BUTTON, url=url)]] + ), + ) + + +@bot.on_message(filters=filters.private & filters.command("start")) +async def start(client: Client, message: Message): + if message.command and len(message.command) >= 2: + action = message.command[1] + if action == "login": + await login(message) + return + me = await client.get_me() + await message.reply(f"Hello, I'm {me.first_name}. ") diff --git a/src/users/__init__.py b/src/users/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/users/backends.py b/src/users/backends.py new file mode 100644 index 0000000..8fee728 --- /dev/null +++ b/src/users/backends.py @@ -0,0 +1,32 @@ +from fastapi.security.utils import get_authorization_scheme_param +from starlette.authentication import AuthCredentials, AuthenticationBackend + +from .crypto import authenticate, read_rsa_key_from_env +from .models import User, UserAnonymous +from ..config import settings + + +class TokenAuthenticationBackend(AuthenticationBackend): + async def authenticate(self, request): + authorization: str = request.headers.get("Authorization") + _, bearer_token = get_authorization_scheme_param(authorization) + + token: str = request.cookies.get("access_token") or bearer_token + + if not token: + return AuthCredentials(), UserAnonymous() + + key = read_rsa_key_from_env(settings.JWT_PUBLIC_KEY) + + is_authenticated, decoded_token = authenticate(token=token, key=key) + + if is_authenticated: + return AuthCredentials(), User( + id=decoded_token["sub"], + is_superuser=decoded_token["is_superuser"], + is_blocked=decoded_token["is_blocked"], + is_active=decoded_token["is_active"], + username=decoded_token["username"], + ) + + return AuthCredentials(), UserAnonymous() diff --git a/src/users/crud.py b/src/users/crud.py new file mode 100644 index 0000000..b3a61f6 --- /dev/null +++ b/src/users/crud.py @@ -0,0 +1,40 @@ +from typing import Optional + +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload +from sqlalchemy.sql.expression import Update + +from .models import User +from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage + + +class SQLAlchemyCRUD: + def __init__(self, storage: SQLAlchemyStorage): + self.storage = storage + + async def get_by_tg_id(self, tg_id: int) -> Optional[User]: + q_results = await self.storage.select( + select(User) + .options( + # for relationship loading, eager loading should be applied. + selectinload(User.user_tokens) + ) + .where(User.tg_id == tg_id) + ) + + return q_results.scalars().one_or_none() + + async def create(self, **kwargs) -> None: + user = User(**kwargs) + await self.storage.add(user) + + async def update(self, user: User, **kwargs) -> None: + await self.storage.update( + Update(User).where(User.id == user.id).values(**kwargs) + ) + + +def get_user_crud(storage: SQLAlchemyStorage = None) -> SQLAlchemyCRUD: + if storage is None: + storage = get_sqlalchemy_storage() + return SQLAlchemyCRUD(storage=storage) diff --git a/src/users/crypto.py b/src/users/crypto.py new file mode 100644 index 0000000..84b0724 --- /dev/null +++ b/src/users/crypto.py @@ -0,0 +1,166 @@ +import base64 +import pathlib +import re +import string +import uuid +from datetime import datetime, timedelta, timezone +from typing import Dict, Tuple + +from Crypto.PublicKey import RSA +from jose import constants, jwt +from jose.exceptions import JWTError + +from ..config import settings + +RANDOM_STRING_CHARS = string.ascii_lowercase + string.ascii_uppercase + string.digits +KEYS = {} + + +def reformat_rsa_key(rsa_key: str) -> str: + """Reformat an RSA PEM key without newlines to one with correct newline characters + + @param rsa_key: the PEM RSA key lacking newline characters + @return: the reformatted PEM RSA key with appropriate newline characters + """ + # split headers from the body + split_rsa_key = re.split(r"(-+)", rsa_key) + + # add newlines between headers and body + split_rsa_key.insert(4, "\n") + split_rsa_key.insert(6, "\n") + + reformatted_rsa_key = "".join(split_rsa_key) + + # reformat body + return RSA.importKey(reformatted_rsa_key).exportKey().decode("utf-8") + + +def read_rsa_key_from_env(file_path: str) -> str: + if file_path in KEYS: + return KEYS[file_path] + path = pathlib.Path(file_path) + + # path to rsa key file + if path.is_file(): + with open(file_path, "rb") as key_file: + jwt_private_key = RSA.importKey(key_file.read()).exportKey() + k = jwt_private_key.decode("utf-8") + KEYS[file_path] = k + return k + + # rsa key without newlines + if "\n" not in file_path: + k = reformat_rsa_key(file_path) + KEYS[file_path] = k + return k + + return file_path + + +def get_n(rsa: RSA): + bytes_data = rsa.n.to_bytes((rsa.n.bit_length() + 7) // 8, 'big') + return base64.urlsafe_b64encode(bytes_data).decode('utf-8') + + +def get_pub_key_resp(): + pub_key = RSA.importKey(read_rsa_key_from_env(settings.JWT_PUBLIC_KEY)) + return { + "keys": [ + { + "n": get_n(pub_key), + "kty": "RSA", + "alg": "RS256", + "kid": "sig", + "e": "AQAB", + "use": "sig" + } + ] + } + + +def encode_jwt( + expires_delta, + sub, + secret=None, + additional_claims=None, + algorithm=constants.ALGORITHMS.RS256, +): + if additional_claims is None: + additional_claims = {} + if secret is None: + secret = read_rsa_key_from_env(settings.JWT_PRIVATE_KEY) + now = datetime.now(timezone.utc) + + claims = { + "iat": now, + "jti": str(uuid.uuid4()), + "nbf": now, + "sub": sub, + "exp": now + timedelta(seconds=expires_delta), + **additional_claims, + } + + return jwt.encode( + claims, + secret, + algorithm, + ) + + +def decode_jwt( + encoded_token, + secret=None, + algorithms=None, +): + if algorithms is None: + algorithms = constants.ALGORITHMS.RS256 + if secret is None: + secret = read_rsa_key_from_env(settings.JWT_PRIVATE_KEY) + return jwt.decode( + encoded_token, + secret, + algorithms=algorithms, + ) + + +def get_jwt(user): + access_token = encode_jwt( + sub=str(user.id), + expires_delta=settings.ACCESS_TOKEN_EXP, + additional_claims={ + "token_type": "access", + "is_blocked": user.is_blocked, + "is_superuser": user.is_superuser, + "username": user.username, + "is_active": user.is_active, + }, + ) + + refresh_token = encode_jwt( + sub=str(user.id), + expires_delta=settings.REFRESH_TOKEN_EXP, + additional_claims={ + "token_type": "refresh", + "is_blocked": user.is_blocked, + "is_superuser": user.is_superuser, + "username": user.username, + "is_active": user.is_active, + }, + ) + + return access_token, refresh_token + + +def authenticate( + *, + token: str, + key: str, +) -> Tuple[bool, Dict]: + """Authenticate user by token""" + try: + token_header = jwt.get_unverified_header(token) + decoded_token = jwt.decode(token, key, algorithms=token_header.get("alg")) + except JWTError: + return False, {} + else: + return True, decoded_token diff --git a/src/users/endpoints.py b/src/users/endpoints.py new file mode 100644 index 0000000..c4e4a61 --- /dev/null +++ b/src/users/endpoints.py @@ -0,0 +1,79 @@ +from http import HTTPStatus + +from fastapi import APIRouter, Depends, HTTPException +from jose import JWTError +from starlette.requests import Request + +from .crud import SQLAlchemyCRUD +from .crypto import get_jwt +from ..config import settings +from ..html import templates +from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage +from ..utils.oauth import back_auth_request +from ..utils.redirect import RedirectResponseBuilder +from ..utils.telegram import decode_telegram_auth_data, verify_telegram_auth_data + +router = APIRouter() + + +@router.get("/login", name="users:login:get") +async def user_login_get(request: Request): + if request.user.is_authenticated: + if resp := await back_auth_request(request): + return resp + return RedirectResponseBuilder().build(settings.PROJECT_LOGIN_SUCCESS_URL) + url = request.url + callback_url = str(url).replace("/login", "/callback") + return templates.TemplateResponse( + "login.jinja", + {"request": request, "callback_url": callback_url, "username": settings.BOT_USERNAME} + ) + + +async def auth( + tg_id: int, + request: Request, + storage: SQLAlchemyStorage, +): + if tg_id is None: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED) + crud = SQLAlchemyCRUD(storage=storage) + user = await crud.get_by_tg_id(tg_id=tg_id) + + if user is None: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED) + + access_token, refresh_token = get_jwt(user) + # NOTE: Setting expire causes an exception for requests library: + # https://github.com/psf/requests/issues/6004 + if resp := await back_auth_request(request, access_token, refresh_token): + return resp + resp = RedirectResponseBuilder() + resp.set_cookie( + key="access_token", value=access_token, max_age=settings.ACCESS_TOKEN_EXP + ) + resp.set_cookie( + key="refresh_token", value=refresh_token, max_age=settings.REFRESH_TOKEN_EXP + ) + return resp.build(settings.PROJECT_LOGIN_SUCCESS_URL) + + +@router.get("/callback", name="users:login") +async def user_login( + request: Request, + storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage), +): + tg_id = await verify_telegram_auth_data(request.query_params) + return await auth(tg_id, request, storage) + + +@router.get("/auth", name="users:auth") +async def user_auth( + request: Request, + storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage), +): + try: + tg_id = await decode_telegram_auth_data(request.query_params) + except JWTError: + tg_id = None + return await auth(tg_id, request, storage) diff --git a/src/users/models.py b/src/users/models.py new file mode 100644 index 0000000..daf7643 --- /dev/null +++ b/src/users/models.py @@ -0,0 +1,37 @@ +from typing import TYPE_CHECKING, List, Optional + +from pydantic import BaseModel +from sqlalchemy import Column, BigInteger +from sqlmodel.main import Field, Relationship + +from ..storage.models import BaseTable + +if TYPE_CHECKING: # pragma: no cover + from ..oauth2.models import AuthorizationCode, Token + + +class UserAnonymous(BaseModel): + @property + def is_authenticated(self) -> bool: + return False + + +class User(BaseTable, table=True): # type: ignore + __tablename__ = "users" + + is_superuser: bool = False + is_blocked: bool = False + is_active: bool = False + + username: str = Field(nullable=False, sa_column_kwargs={"unique": True}, index=True) + password: Optional[str] = None + tg_id: int = Field(sa_column=Column(BigInteger(), nullable=False)) + + user_authorization_codes: List["AuthorizationCode"] = Relationship( + back_populates="user" + ) + user_tokens: List["Token"] = Relationship(back_populates="user") + + @property + def is_authenticated(self) -> bool: + return True diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/oauth.py b/src/utils/oauth.py new file mode 100644 index 0000000..bafe9b2 --- /dev/null +++ b/src/utils/oauth.py @@ -0,0 +1,38 @@ +from typing import Optional + +from starlette.requests import Request +from starlette.responses import RedirectResponse + +from src.config import settings +from src.users.crypto import encode_jwt, decode_jwt +from src.utils.redirect import RedirectResponseBuilder + + +async def to_login_request(request: Request) -> RedirectResponse: + query_params = dict(request.query_params) + params = "" + for key, value in query_params.items(): + params += f"{key}={value}&" + params = params[:-1] + jwt = encode_jwt(settings.ACCESS_TOKEN_EXP, "", additional_claims={"params": params}) + resp = RedirectResponseBuilder() + resp.set_cookie("SEND", jwt, max_age=settings.ACCESS_TOKEN_EXP) + return resp.build("/api/users/login") + + +async def back_auth_request( + request: Request, + access_token: str = None, + refresh_token: str = None, +) -> Optional[RedirectResponse]: + cookie = request.cookies.get("SEND") + if cookie is None: + return None + params = decode_jwt(cookie)["params"] + resp = RedirectResponseBuilder() + if access_token: + resp.set_cookie("access_token", access_token, max_age=settings.ACCESS_TOKEN_EXP) + if refresh_token: + resp.set_cookie("refresh_token", refresh_token, max_age=settings.ACCESS_TOKEN_EXP) + resp.delete_cookie("SEND") + return resp.build(f"/oauth2/authorize?{params}", status_code=303) diff --git a/src/utils/redirect.py b/src/utils/redirect.py new file mode 100644 index 0000000..10af040 --- /dev/null +++ b/src/utils/redirect.py @@ -0,0 +1,78 @@ +import http.cookies +import typing +from datetime import datetime +from email.utils import format_datetime + +from starlette.datastructures import MutableHeaders +from starlette.responses import RedirectResponse + + +class RedirectResponseBuilder: + def __init__(self): + self.raw_headers = [] + + def set_cookie( + self, + key: str, + value: str = "", + max_age: typing.Optional[int] = None, + expires: typing.Optional[typing.Union[datetime, str, int]] = None, + path: str = "/", + domain: typing.Optional[str] = None, + secure: bool = False, + httponly: bool = False, + samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax", + ) -> None: + cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie() + cookie[key] = value + if max_age is not None: + cookie[key]["max-age"] = max_age + if expires is not None: + if isinstance(expires, datetime): + cookie[key]["expires"] = format_datetime(expires, usegmt=True) + else: + cookie[key]["expires"] = expires + if path is not None: + cookie[key]["path"] = path + if domain is not None: + cookie[key]["domain"] = domain + if secure: + cookie[key]["secure"] = True + if httponly: + cookie[key]["httponly"] = True + if samesite is not None: + assert samesite.lower() in [ + "strict", + "lax", + "none", + ], "samesite must be either 'strict', 'lax' or 'none'" + cookie[key]["samesite"] = samesite + cookie_val = cookie.output(header="").strip() + self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1"))) + + def delete_cookie( + self, + key: str, + path: str = "/", + domain: typing.Optional[str] = None, + secure: bool = False, + httponly: bool = False, + samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax", + ) -> None: + self.set_cookie( + key, + max_age=0, + expires=0, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + samesite=samesite, + ) + + @property + def headers(self) -> MutableHeaders: + return MutableHeaders(raw=self.raw_headers) + + def build(self, url: str, status_code: int = 307): + return RedirectResponse(url, headers=self.headers, status_code=status_code) diff --git a/src/utils/telegram.py b/src/utils/telegram.py new file mode 100644 index 0000000..f5c0ddf --- /dev/null +++ b/src/utils/telegram.py @@ -0,0 +1,45 @@ +import hashlib +import hmac +from datetime import datetime, timezone +from typing import Optional + +from starlette.datastructures import QueryParams + +from src.config import settings +from src.users.crypto import encode_jwt, decode_jwt + + +async def verify_telegram_auth_data(params: QueryParams) -> Optional[int]: + data = list(params.items()) + hash_str = "" + text_list = [] + for key, value in data: + if key == "hash": + hash_str = value + else: + text_list.append(f"{key}={value}") + check_string = "\n".join(sorted(text_list)) + + secret_key = hashlib.sha256(str.encode(settings.BOT_TOKEN)).digest() + hmac_hash = hmac.new(secret_key, str.encode(check_string), hashlib.sha256).hexdigest() + + return int(params.get("id")) if hmac_hash == hash_str else None + + +async def encode_telegram_auth_data(uid: int) -> str: + jwt = encode_jwt(settings.ACCESS_TOKEN_EXP, str(uid)) + return jwt + + +async def decode_telegram_auth_data(params: QueryParams) -> Optional[int]: + jwt = params.get("jwt") + if not jwt: + return None + if not jwt: + return None + data = decode_jwt(jwt) + now = datetime.now(timezone.utc) + uid, exp = data["sub"], data["exp"] + if exp < now.timestamp(): + return None + return int(uid)