add guesslang tf model
This commit is contained in:
parent
1ca24c58e5
commit
4ee081df66
|
@ -1 +1,2 @@
|
|||
/target
|
||||
venv
|
File diff suppressed because it is too large
Load Diff
|
@ -10,6 +10,7 @@ actix-web = "4"
|
|||
silicon = { git = "https://github.com/watzon/silicon.git" }
|
||||
lazy_static = "1.4.0"
|
||||
serde = { version = "1.0.130", features = ["derive"] }
|
||||
serde_json = "1.0.107"
|
||||
structopt = "0.3.26"
|
||||
image = "0.24.7"
|
||||
anyhow = "1.0.75"
|
||||
|
@ -18,4 +19,6 @@ syntect = "5.1.0"
|
|||
font-kit = "0.11.0"
|
||||
reqwest = "0.11.22"
|
||||
hyperpolyglot = "0.1.7"
|
||||
tempfile = "3.8.0"
|
||||
tempfile = "3.8.0"
|
||||
tensorflow = "0.17.0"
|
||||
clap = { version = "4.4.7", features = ["derive"] }
|
|
@ -1,7 +1,44 @@
|
|||
FROM debian:buster-slim as tensorflow
|
||||
|
||||
WORKDIR /usr/src/build
|
||||
|
||||
# Install dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
git \
|
||||
wget \
|
||||
gnupg \
|
||||
python3 \
|
||||
python3-dev \
|
||||
python3-pip \
|
||||
python3-numpy \
|
||||
llvm \
|
||||
clang
|
||||
|
||||
RUN pip3 install wheel packaging requests opt_einsum
|
||||
RUN pip3 install keras_preprocessing --no-deps
|
||||
|
||||
# Install bazel
|
||||
RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.18.0/bazelisk-linux-amd64
|
||||
RUN chmod +x bazelisk-linux-amd64
|
||||
RUN mv bazelisk-linux-amd64 /usr/local/bin/bazel
|
||||
|
||||
# Install tensorflow
|
||||
RUN git clone https://github.com/tensorflow/tensorflow \
|
||||
&& cd tensorflow \
|
||||
&& git checkout v2.5.0
|
||||
RUN cd tensorflow && ./configure
|
||||
RUN cd tensorflow && bazel build --compilation_mode=opt --copt=-march=native --jobs=12 tensorflow:libtensorflow.so
|
||||
|
||||
FROM rust:1.73.0-buster as builder
|
||||
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
# Copy tensorflow shared libraries from tensorflow image
|
||||
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow.so* /usr/local/lib/
|
||||
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so* /usr/local/lib/
|
||||
|
||||
RUN ldconfig
|
||||
|
||||
# Install dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libssl-dev \
|
||||
|
@ -30,6 +67,8 @@ RUN bash ./download_nerd_fonts.sh
|
|||
|
||||
FROM debian:buster-slim
|
||||
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
# Install dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libssl-dev \
|
||||
|
@ -41,8 +80,15 @@ RUN apt-get update && apt-get install -y \
|
|||
COPY --from=fonts /data/fonts/nerd_fonts/* /usr/share/fonts/truetype/
|
||||
RUN fc-cache -fv
|
||||
|
||||
# Copy binary
|
||||
COPY --from=builder /usr/src/app/target/release/inkify /usr/local/bin/inkify
|
||||
# Copy binary and tensorflow model files
|
||||
COPY --from=builder /usr/src/app/target/release/inkify /usr/src/app/
|
||||
COPY --from=builder /usr/src/app/src/tensorflow /usr/src/app/tensorflow/
|
||||
|
||||
# Copy tensorflow shared libraries from tensorflow image
|
||||
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow.so* /usr/local/lib/
|
||||
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so* /usr/local/lib/
|
||||
|
||||
RUN ldconfig
|
||||
|
||||
ARG PORT=8080
|
||||
ARG HOST=0.0.0.0
|
||||
|
@ -53,4 +99,4 @@ ENV HOST=$HOST
|
|||
EXPOSE $PORT
|
||||
|
||||
# Run
|
||||
ENTRYPOINT ["/usr/local/bin/inkify"]
|
||||
CMD ["/usr/src/app/inkify", "--tensorflow-model-dir", "/usr/src/app/tensorflow"]
|
|
@ -1,17 +1,17 @@
|
|||
use anyhow::Error;
|
||||
use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
|
||||
use silicon::utils::{Background, ShadowAdder};
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use syntect::highlighting::{Theme, ThemeSet};
|
||||
use syntect::parsing::{SyntaxReference, SyntaxSet};
|
||||
use tensorflow::{Graph, SavedModelBundle, SessionOptions, Tensor};
|
||||
|
||||
use crate::rgba::{ImageRgba, Rgba};
|
||||
|
||||
type FontList = Vec<(String, f32)>;
|
||||
type Lines = Vec<u32>;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
/// Background image URL
|
||||
pub background_image: Option<Vec<u8>>,
|
||||
|
@ -72,6 +72,12 @@ pub struct Config {
|
|||
|
||||
/// The syntax highlight theme. It can be a theme name or path to a .tmTheme file.
|
||||
pub theme: String,
|
||||
|
||||
#[serde(skip_deserializing)]
|
||||
pub tf_model_graph: Option<Graph>,
|
||||
|
||||
#[serde(skip_deserializing)]
|
||||
pub tf_model: Option<SavedModelBundle>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
|
@ -97,9 +103,25 @@ impl Config {
|
|||
shadow_offset_x: 0,
|
||||
tab_width: 4,
|
||||
theme: "Dracula".to_owned(),
|
||||
tf_model_graph: None,
|
||||
tf_model: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_tensorflow_model(&mut self, export_dir: &str) {
|
||||
let mut graph = Graph::new();
|
||||
let model = match SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir) {
|
||||
Ok(model) => model,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to load TensorFlow model: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
self.tf_model = Some(model);
|
||||
self.tf_model_graph = Some(graph);
|
||||
}
|
||||
|
||||
pub fn language<'a>(&self, ps: &'a SyntaxSet) -> Result<&'a SyntaxReference, Error> {
|
||||
let language = match &self.language {
|
||||
Some(language) => ps
|
||||
|
@ -108,20 +130,51 @@ impl Config {
|
|||
None => {
|
||||
let first_line = self.code.lines().next().unwrap_or_default();
|
||||
ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
|
||||
// hyperpolyglot requires a file, so we need to create a temp file
|
||||
let mut temp_file = tempfile::NamedTempFile::new().unwrap();
|
||||
write!(temp_file, "{}", self.code).unwrap();
|
||||
let language = hyperpolyglot::detect(temp_file.path()).unwrap();
|
||||
match language {
|
||||
Some(language) => ps.find_syntax_by_token(language.language()).unwrap(),
|
||||
None => ps.find_syntax_by_token("log").unwrap(),
|
||||
}
|
||||
// Try using tensorflow to detect the language
|
||||
let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap();
|
||||
self.predict_language_with_tensorflow(ps, input_data)
|
||||
.unwrap_or_else(|_| ps.find_syntax_by_token("log").unwrap())
|
||||
})
|
||||
},
|
||||
};
|
||||
Ok(language)
|
||||
}
|
||||
|
||||
pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<&'a SyntaxReference, Error> {
|
||||
if self.tf_model_graph.is_none() || self.tf_model.is_none() {
|
||||
return Err(Error::msg("TensorFlow model not loaded"));
|
||||
}
|
||||
|
||||
let graph = self.tf_model_graph.as_ref().unwrap();
|
||||
let model = self.tf_model.as_ref().unwrap();
|
||||
let mut args = tensorflow::SessionRunArgs::new();
|
||||
|
||||
let input_tensor = graph.operation_by_name_required("Placeholder")?;
|
||||
|
||||
let output_tensor_scores = graph.operation_by_name_required("head/predictions/probabilities")?;
|
||||
|
||||
let output_tensor_classes = graph.operation_by_name_required("head/Tile")?;
|
||||
|
||||
args.add_feed(&input_tensor, 0, &input_data);
|
||||
let output_token_scores = args.request_fetch(&output_tensor_scores, 0);
|
||||
let output_token_classes = args.request_fetch(&output_tensor_classes, 0);
|
||||
|
||||
model.session.run(&mut args)?;
|
||||
|
||||
let scores: Tensor<f32> = args.fetch(output_token_scores)?;
|
||||
|
||||
let classes: Tensor<String> = args.fetch(output_token_classes)?;
|
||||
|
||||
// Find the index of the highest score
|
||||
let max_index = scores.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap().0;
|
||||
|
||||
|
||||
let language = &classes[max_index];
|
||||
let language = ps.find_syntax_by_token(language).unwrap();
|
||||
|
||||
Ok(language)
|
||||
}
|
||||
|
||||
|
||||
pub fn theme(&self, ts: &ThemeSet) -> Result<Theme, Error> {
|
||||
if let Some(theme) = ts.themes.get(&self.theme) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#[macro_use]
|
||||
extern crate anyhow;
|
||||
|
||||
use clap::Parser;
|
||||
use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
|
||||
use anyhow::Error;
|
||||
use lazy_static::lazy_static;
|
||||
|
@ -28,6 +29,13 @@ macro_rules! unwrap_or_return {
|
|||
};
|
||||
}
|
||||
|
||||
// Command-line arguments for the server binary, parsed with clap's derive API.
// Plain `//` comments are used on purpose: clap turns `///` doc comments into help text.
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct CliArgs {
|
||||
// Optional directory of the TensorFlow model; when present, the `generate`
// handler passes it to `load_tensorflow_model`.
#[arg(short, long)]
|
||||
tensorflow_model_dir: Option<String>,
|
||||
}
|
||||
|
||||
fn parse_font_str(s: &str) -> Vec<(String, f32)> {
|
||||
let mut result = vec![];
|
||||
for font in s.split(';') {
|
||||
|
@ -144,6 +152,7 @@ async fn fonts() -> impl Responder {
|
|||
|
||||
#[get("/generate")]
|
||||
async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
||||
let args = CliArgs::parse();
|
||||
let ha = &*HIGHLIGHTING_ASSETS;
|
||||
|
||||
let (ps, ts) = (&ha.syntax_set, &ha.theme_set);
|
||||
|
@ -156,6 +165,10 @@ async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
|||
.body(r#"{"error": "code parameter is required"}"#);
|
||||
}
|
||||
|
||||
if args.tensorflow_model_dir.is_some() {
|
||||
conf.load_tensorflow_model(args.tensorflow_model_dir.unwrap().as_str());
|
||||
}
|
||||
|
||||
conf.language = info.language.clone();
|
||||
if let Some(theme) = info.theme.clone() {
|
||||
conf.theme = theme;
|
||||
|
@ -227,6 +240,7 @@ async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
|||
.append_header(("Content-Type", "application/json"))
|
||||
.body(r#"{"error": "Unable to determine language, please provide one explicitly"}"#)
|
||||
);
|
||||
|
||||
let theme = unwrap_or_return!(
|
||||
conf.theme(ts),
|
||||
HttpResponse::BadRequest()
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue